Compare commits

..

78 Commits

Author SHA1 Message Date
comfyanonymous
8e4118c0de make dpm_2_ancestral work with rectified flow. 2024-12-01 07:37:41 -05:00
comfyanonymous
3fc6ebcdd7 Add basic style model "multiply" strength. 2024-11-30 07:27:11 -05:00
comfyanonymous
20a560eb97 How to enable experimental memory efficient attention on ROCm RDNA3. 2024-11-29 06:19:49 -05:00
Dr.Lt.Data
82c5308561 Backward compatibility patch for changes in the method signature of InpaintModelConditioning. (#5825)
https://github.com/comfyanonymous/ComfyUI/issues/5813
2024-11-28 20:30:28 -05:00
comfyanonymous
26fb2c68e8 Add a way to disable cropping in the CLIPVisionEncode node. 2024-11-28 20:24:47 -05:00
comfyanonymous
bf2650a80e Fast previews for ltxv. 2024-11-28 06:46:15 -05:00
Chenlei Hu
53646e0f32 Update web content to release v1.4.13 (#5807) 2024-11-28 04:59:06 -05:00
Chenlei Hu
20879c78f9 Remove internal model download endpoint (#5432) 2024-11-28 04:57:06 -05:00
comfyanonymous
b666539595 Remove print. 2024-11-27 20:28:39 -05:00
comfyanonymous
95d8713482 Missing parentheses. 2024-11-27 13:45:32 -05:00
comfyanonymous
0d4e29f13f LTXV model merging node. 2024-11-27 01:43:31 -05:00
comfyanonymous
497db6212f Alternative fix for #5767 2024-11-26 17:53:04 -05:00
lky
24dc581dc3 fix multi add makedirs error (#5786)
Trying to start multiple ComfyUI servers at the same time caused this error.
2024-11-26 15:34:19 -05:00
comfyanonymous
4c82741b54 Support official SD3.5 Controlnets. 2024-11-26 11:31:25 -05:00
comfyanonymous
15c39ea757 Support for the official mochi lora format. 2024-11-26 03:34:36 -05:00
comfyanonymous
b7143b74ce Flux inpaint model does not work in fp16. 2024-11-26 01:33:01 -05:00
comfyanonymous
61196d8857 Add option to inference the diffusion model in fp32 and fp64. 2024-11-25 05:00:23 -05:00
comfyanonymous
b4526d3fc3 Skip layer guidance now works on hydit model. 2024-11-24 05:54:30 -05:00
40476
3d802710e7 Update README.md (#5707) 2024-11-24 04:12:07 -05:00
spacepxl
7126ecffde set LTX min length to 1 for t2i (#5750)
At length=1, the LTX model can do txt2img and img2img with no other changes required.
2024-11-23 21:33:08 -05:00
comfyanonymous
ab885b33ba Skip layer guidance node now works on LTX-Video. 2024-11-23 10:33:05 -05:00
comfyanonymous
839ed3368e Some improvements to the lowvram unloading. 2024-11-22 20:59:15 -05:00
comfyanonymous
6e8cdcd3cb Fix some tiled VAE decoding issues with LTX-Video. 2024-11-22 18:00:34 -05:00
comfyanonymous
e5c3f4b87f LTXV lowvram fixes. 2024-11-22 17:17:11 -05:00
comfyanonymous
bc6be6c11e Some fixes to the lowvram system. 2024-11-22 16:40:04 -05:00
comfyanonymous
94323a26a7 Remove prints. 2024-11-22 10:51:31 -05:00
comfyanonymous
5818f6cf51 Remove print. 2024-11-22 10:49:15 -05:00
comfyanonymous
0b734de449 Add LTX-Video support to the Readme. 2024-11-22 09:24:20 -05:00
comfyanonymous
5e16f1d24b Support Lightricks LTX-Video model. 2024-11-22 08:46:39 -05:00
comfyanonymous
2fd9c1308a Fix mask issue in some attention functions. 2024-11-22 02:10:09 -05:00
comfyanonymous
8f0009aad0 Support new flux model variants. 2024-11-21 08:38:23 -05:00
comfyanonymous
41444b5236 Add some new weight patching functionality.
Add a way to reshape lora weights.

Allow weight patches on all weights, not just .weight and .bias

Add a way for a lora to set a weight to a specific value.
2024-11-21 07:19:17 -05:00
comfyanonymous
772e620e32 Update readme. 2024-11-20 20:42:51 -05:00
comfyanonymous
07f6eeaa13 Fix mask issue with attention_xformers. 2024-11-20 17:07:46 -05:00
comfyanonymous
22535d0589 Skip layer guidance now works on stable audio model. 2024-11-20 07:33:06 -05:00
comfyanonymous
898615122f Rename add_noise_mask -> noise_mask. 2024-11-19 15:31:09 -05:00
comfyanonymous
156a28786b Add boolean to InpaintModelConditioning to disable the noise mask. 2024-11-19 07:31:29 -05:00
Yoland Yan
f498d855ba Add terminal size fallback (#5623) 2024-11-19 03:34:20 -05:00
comfyanonymous
b699a15062 Refactor inpaint/ip2p code. 2024-11-19 03:25:25 -05:00
Chenlei Hu
9cc90ee3eb Update UI screenshot in README (#5666)
* Update UI ScreenShot in README

* Remove legacy UI screenshot file

* nit

* nit
2024-11-18 16:50:34 -05:00
comfyanonymous
9a0a5d32ee Add a skip layer guidance node that can also skip single layers.
This one should work for skipping the single layers of models like Flux
and Auraflow.

If you want to see how these models work and how many double/single layers
they have, see the "ModelMerge*" nodes for the specific model.
2024-11-18 02:20:43 -05:00
comfyanonymous
d9f90965c8 Support block replace patches in auraflow. 2024-11-17 08:19:59 -05:00
comfyanonymous
41886af138 Add transformer options blocks replace patch to mochi. 2024-11-16 20:48:14 -05:00
Chenlei Hu
22a1d7ce78 Fix 3.8 compatibility in user_manager.py (#5645) 2024-11-16 20:42:21 -05:00
Chenlei Hu
4ac401af2b Update web content to release v1.3.44 (#5620)
* Update web content to release v1.3.44

* nit
2024-11-15 20:17:15 -05:00
comfyanonymous
5fb59c8475 Add a node to block merge auraflow models. 2024-11-15 12:47:55 -05:00
comfyanonymous
122c9ca1ce Add advanced model merging node for mochi. 2024-11-14 07:51:20 -05:00
comfyanonymous
3b9a6cf2b1 Fix issue with 3d masks. 2024-11-13 07:18:30 -05:00
comfyanonymous
3748e7ef7a Fix regression. 2024-11-13 04:24:48 -05:00
comfyanonymous
8ebf2d8831 Add block replace transformer_options to flux. 2024-11-12 08:00:39 -05:00
Bratzmeister
a72d152b0c fix --cuda-device arg for AMD/HIP devices (#5586)
* fix --cuda-device arg for AMD/HIP devices

CUDA_VISIBLE_DEVICES is ignored by the HIP devices/backend, which uses HIP_VISIBLE_DEVICES instead. Setting either environment variable has no side effect on the other backend, so both can safely be set in any case.

* deleted accidental if
2024-11-12 06:53:36 -05:00
comfyanonymous
eb476e6ea9 Allow 1D masks for 1D latents. 2024-11-11 14:44:52 -05:00
Dr.Lt.Data
2d28b0b479 improve: add descriptions for clip loaders (#5576) 2024-11-11 05:37:23 -05:00
comfyanonymous
8b275ce5be Support auto detecting some zsnr anime checkpoints. 2024-11-11 05:34:11 -05:00
comfyanonymous
2a18e98ccf Refactor so that zsnr can be set in the sampling_settings. 2024-11-11 04:55:56 -05:00
comfyanonymous
8a5281006f Fix some custom nodes. 2024-11-10 22:41:00 -05:00
comfyanonymous
bdeb1c171c Fast previews for mochi. 2024-11-10 03:39:35 -05:00
comfyanonymous
9c1ed58ef2 proper fix for sag. 2024-11-10 00:10:45 -05:00
comfyanonymous
8b90e50979 Properly handle and reshape masks when used on 3d latents. 2024-11-09 15:30:19 -05:00
pythongosssss
6ee066a14f Live terminal output (#5396)
* Add /logs/raw and /logs/subscribe for getting logs on frontend
Hijacks stderr/stdout to send all output data to the client on flush

* Use existing send sync method

* Fix get_logs should return string

* Fix bug

* pass no server

* fix tests

* Fix output flush on linux
2024-11-08 19:13:34 -05:00
DenOfEquity
dd5b57e3d7 fix for SAG with Kohya HRFix/ Deep Shrink (#5546)
now works with arbitrary downscale factors
2024-11-08 18:16:29 -05:00
comfyanonymous
75a818c720 Move mochi latent node to: latent/video. 2024-11-08 08:33:44 -05:00
comfyanonymous
2865f913f7 Free memory before doing tiled decode. 2024-11-07 04:01:24 -05:00
comfyanonymous
b49616f951 Make VAEDecodeTiled node work with video VAEs. 2024-11-07 03:47:12 -05:00
comfyanonymous
5e29e7a488 Remove scaled_fp8 key after reading it to silence warning. 2024-11-06 04:56:42 -05:00
comfyanonymous
8afb97cd3f Fix unknown VAE being detected as the mochi VAE. 2024-11-05 03:43:27 -05:00
contentis
69694f40b3 fix dynamic shape export (#5490) 2024-11-04 14:59:28 -05:00
Chenlei Hu
c49025f01b Allow POST /userdata/{file} endpoint to return full file info (#5446)
* Refactor listuserdata

* Full info param

* Add tests

* Fix mock

* Add full_info support for move user file
2024-11-04 13:57:21 -05:00
comfyanonymous
696672905f Add mochi support to readme. 2024-11-04 04:55:07 -05:00
comfyanonymous
6c9dbde7de Fix mochi all in one checkpoint t5xxl key names. 2024-11-03 01:40:42 -05:00
comfyanonymous
ee8abf0cff Update folder paths: "clip" -> "text_encoders"
You can still use models/clip but the folder might get removed eventually
on new installs of ComfyUI.
2024-11-02 15:35:38 -04:00
comfyanonymous
fabf449feb Mochi VAE encoder. 2024-11-01 17:33:09 -04:00
Uriel Deveaud
cc9cf6d1bd Rename some nodes in Display Name Mappings (nodes.py) (#5439)
* Update nodes_images.py

The Nodes menu has inconsistent names: some have spaces between words, others do not.

* Update nodes.py

Include the node mapping name line for Image Crop Node

* Update nodes_images.py

* Rename image nodes

Add spaces between words for consistency in the display name mappings.
2024-10-31 15:18:05 -04:00
Aarni Koskela
1c8286a44b Avoid SyntaxWarning in UniPC docstring (#5442) 2024-10-31 15:17:26 -04:00
comfyanonymous
1af4a47fd1 Bump up mac version for attention upcast bug workaround. 2024-10-31 15:15:31 -04:00
Uriel Deveaud
f2aaa0a475 Rename ImageCrop to Image Crop (#5424)
* Update nodes_images.py

The Nodes menu has inconsistent names: some have spaces between words, others do not.

* Update nodes.py

Include the node mapping name line for Image Crop Node

* Update nodes_images.py
2024-10-31 00:35:34 -04:00
comfyanonymous
daa1565b93 Fix diffusers flux controlnet regression. 2024-10-30 13:11:34 -04:00
comfyanonymous
09fdb2b269 Support SD3.5 medium diffusers format weights and loras. 2024-10-30 04:24:00 -04:00
121 changed files with 104597 additions and 41297 deletions

View File

@@ -28,7 +28,7 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](comfyui_screenshot.png)
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
</div>
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
@@ -39,7 +39,9 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -73,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| `Ctrl` + `Enter` | Queue up current graph for generation |
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
| `Ctrl` + `S` | Save workflow |
| `Ctrl` + `O` | Load workflow |
| `Ctrl` + `A` | Select all nodes |
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
| `Ctrl` + `M` | Mute/unmute selected nodes |
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| `Delete`/`Backspace` | Delete selected nodes |
| `Ctrl` + `Backspace` | Delete the current graph |
| `Space` | Move the canvas around when held and moving the cursor |
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
| `Ctrl` + `D` | Load default graph |
| `Alt` + `+` | Canvas Zoom in |
| `Alt` + `-` | Canvas Zoom out |
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
| `P` | Pin/Unpin selected nodes |
| `Ctrl` + `G` | Group selected nodes |
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users
`Ctrl` can also be replaced with `Cmd` instead for macOS users
# Installing
@@ -139,7 +141,7 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
@@ -211,6 +213,12 @@ For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.
For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
### AMD ROCm Tips
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
# Notes
Only parts of the graph that have an output with all the correct inputs will be executed.

View File

@@ -2,6 +2,7 @@ from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger
class InternalRoutes:
@@ -9,9 +10,9 @@ class InternalRoutes:
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self):
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
@@ -19,6 +20,8 @@ class InternalRoutes:
"user": user_directory,
"output": output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
@self.routes.get('/files')
@@ -34,7 +37,28 @@ class InternalRoutes:
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@self.routes.get('/logs/raw')
async def get_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
})
@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)
return web.Response(status=200)
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
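
A minimal usage sketch for the new log endpoints above. The /internal prefix comes from the class docstring; the host/port and the use of the requests library are assumptions for illustration, not part of this diff.

```python
import requests

BASE = "http://127.0.0.1:8188"  # assumed default ComfyUI address

# One-shot fetch of the buffered log entries plus the reported terminal size.
raw = requests.get(f"{BASE}/internal/logs/raw").json()
print(raw["size"])              # e.g. {"cols": 80, "rows": 24}
for entry in raw["entries"]:
    print(entry["t"], entry["m"], end="")

# Subscribe (or unsubscribe) a websocket client to live log pushes; the server
# then sends "logs" messages to that clientId on every stdout/stderr flush.
requests.patch(f"{BASE}/internal/logs/subscribe",
               json={"clientId": "my-client-id", "enabled": True})
```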

View File

@@ -0,0 +1,60 @@
from app.logger import on_flush
import os
import shutil
class TerminalService:
def __init__(self, server):
self.server = server
self.cols = None
self.rows = None
self.subscriptions = set()
on_flush(self.send_messages)
def get_terminal_size(self):
try:
size = os.get_terminal_size()
return (size.columns, size.lines)
except OSError:
try:
size = shutil.get_terminal_size()
return (size.columns, size.lines)
except OSError:
return (80, 24) # fallback to 80x24
def update_size(self):
columns, lines = self.get_terminal_size()
changed = False
if columns != self.cols:
self.cols = columns
changed = True
if lines != self.rows:
self.rows = lines
changed = True
if changed:
return {"cols": self.cols, "rows": self.rows}
return None
def subscribe(self, client_id):
self.subscriptions.add(client_id)
def unsubscribe(self, client_id):
self.subscriptions.discard(client_id)
def send_messages(self, entries):
if not len(entries) or not len(self.subscriptions):
return
new_size = self.update_size()
for client_id in self.subscriptions.copy(): # prevent: Set changed size during iteration
if client_id not in self.server.sockets:
# Automatically unsub if the socket has disconnected
self.unsubscribe(client_id)
continue
self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)

View File

@@ -1,20 +1,69 @@
import logging
from logging.handlers import MemoryHandler
from collections import deque
from datetime import datetime
import io
import logging
import sys
import threading
logs = None
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stdout_interceptor = None
stderr_interceptor = None
class LogInterceptor(io.TextIOWrapper):
def __init__(self, stream, *args, **kwargs):
buffer = stream.buffer
encoding = stream.encoding
super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
self._lock = threading.Lock()
self._flush_callbacks = []
self._logs_since_flush = []
def write(self, data):
entry = {"t": datetime.now().isoformat(), "m": data}
with self._lock:
self._logs_since_flush.append(entry)
# Simple handling for carriage returns: overwrite the last entry if it isn't a full line,
# otherwise the logs just fill up with progress messages
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"):
logs.pop()
logs.append(entry)
super().write(data)
def flush(self):
super().flush()
for cb in self._flush_callbacks:
cb(self._logs_since_flush)
self._logs_since_flush = []
def on_flush(self, callback):
self._flush_callbacks.append(callback)
def get_logs():
return "\n".join([formatter.format(x) for x in logs])
return logs
def on_flush(callback):
if stdout_interceptor is not None:
stdout_interceptor.on_flush(callback)
if stderr_interceptor is not None:
stderr_interceptor.on_flush(callback)
def setup_logger(log_level: str = 'INFO', capacity: int = 300):
global logs
if logs:
return
# Override output streams and log to buffer
logs = deque(maxlen=capacity)
global stdout_interceptor
global stderr_interceptor
stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
# Setup default global logger
logger = logging.getLogger()
logger.setLevel(log_level)
@@ -22,10 +71,3 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(stream_handler)
# Create a memory handler with a deque as its buffer
logs = deque(maxlen=capacity)
memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
memory_handler.buffer = logs
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)
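
A short sketch of how the reworked logger might be used, assuming it is imported as app.logger the way the route file above does; the callback contract (a list of {"t", "m"} entries per flush) is read directly from this diff.

```python
import sys
import app.logger

app.logger.setup_logger(log_level="INFO", capacity=300)

def forward_entries(entries):
    # Called on every stdout/stderr flush with the entries captured since the
    # previous flush; each entry is {"t": ISO timestamp, "m": message text}.
    for e in entries:
        pass  # e.g. push e to a websocket client, as TerminalService does

app.logger.on_flush(forward_entries)

print("hello from the intercepted stdout")  # buffered by LogInterceptor
sys.stdout.flush()                          # triggers forward_entries([...])

recent = app.logger.get_logs()              # deque of the most recent entries
```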

View File

@@ -1,25 +1,42 @@
from __future__ import annotations
import json
import os
import re
import uuid
import glob
import shutil
import logging
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
from typing import TypedDict
default_user = "default"
class FileInfo(TypedDict):
path: str
size: int
modified: int
def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
"modified": os.path.getmtime(path)
}
class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
@@ -154,6 +171,7 @@ class UserManager():
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
split_path = request.rel_url.query.get('split', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
@@ -161,26 +179,21 @@ class UserManager():
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
def process_full_path(full_path: str) -> FileInfo | str | list[str]:
if full_info:
return get_file_info(full_path, path)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
if split_path:
return [rel_path] + rel_path.split('/')
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return rel_path
results = [
process_full_path(full_path)
for full_path in glob.glob(pattern, recursive=recurse)
if os.path.isfile(full_path)
]
return web.json_response(results)
@@ -208,20 +221,51 @@ class UserManager():
@routes.post("/userdata/{file}")
async def post_userdata(request):
"""
Upload or update a user data file.
This endpoint handles file uploads to a user's data directory, with options for
controlling overwrite behavior and response format.
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Path Parameters:
- file: The target file path (URL encoded if necessary).
Returns:
- 400: If 'file' parameter is missing.
- 403: If the requested path is not allowed.
- 409: If overwrite=false and the file already exists.
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
The request body should contain the raw file content to be written.
"""
path = get_user_data_path(request)
if not isinstance(path, str):
return path
overwrite = request.query["overwrite"] != "false"
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
if not overwrite and os.path.exists(path):
return web.Response(status=409)
return web.Response(status=409, text="File already exists")
body = await request.read()
with open(path, "wb") as f:
f.write(body)
resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(path, user_path)
else:
resp = os.path.relpath(path, user_path)
return web.json_response(resp)
@routes.delete("/userdata/{file}")
@@ -236,6 +280,30 @@ class UserManager():
@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source
@@ -244,12 +312,19 @@ class UserManager():
if not isinstance(source, str):
return dest
overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)
overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"
print(f"moving '{source}' -> '{dest}'")
if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
logging.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)
return web.json_response(resp)
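
A hedged usage sketch for the two endpoints documented above. The host/port, the example file names, and the requests dependency are assumptions; the query parameters and status codes come from the docstrings in this diff.

```python
import requests
from urllib.parse import quote

BASE = "http://127.0.0.1:8188"  # assumed default ComfyUI address
name = quote("workflows/example.json", safe="")  # file paths are URL encoded

# Upload a file; full_info=true returns {"path", "size", "modified"} instead of
# just the relative path, and overwrite=false gives 409 if the file exists.
r = requests.post(f"{BASE}/userdata/{name}",
                  params={"overwrite": "true", "full_info": "true"},
                  data=b'{"nodes": []}')
print(r.status_code, r.json())

# Move/rename the file with the same overwrite / full_info semantics.
dest = quote("workflows/renamed.json", safe="")
r = requests.post(f"{BASE}/userdata/{name}/move/{dest}",
                  params={"overwrite": "false", "full_info": "true"})
print(r.status_code, r.json() if r.ok else r.text)
```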

comfy/cldm/dit_embedder.py (new file, 122 lines)
View File

@@ -0,0 +1,122 @@
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")

View File

@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
}
class CLIPMLP(torch.nn.Module):
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
bias=patch_bias,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module):
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)

View File

@@ -16,13 +16,18 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
@@ -35,6 +40,8 @@ class ClipVisionModel():
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -49,9 +56,9 @@ class ClipVisionModel():
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
@@ -94,7 +101,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
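
A standalone sketch of the resize behaviour introduced by the new crop parameter above. It mirrors the patched clip_preprocess for illustration only (it is not the ComfyUI function itself and omits the mean/std normalization).

```python
import torch
import torch.nn.functional as F

def preprocess_like_diff(image, size=224, crop=True):
    # image: NHWC float tensor in [0, 1], as produced by ComfyUI image nodes
    image = image.movedim(-1, 1)  # NHWC -> NCHW
    if not (image.shape[2] == size and image.shape[3] == size):
        if crop:
            # keep aspect ratio, scale the short side to `size`, then center-crop
            scale = size / min(image.shape[2], image.shape[3])
            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
        else:
            # crop disabled: squash directly to size x size
            scale_size = (size, size)
        image = F.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
        h = (image.shape[2] - size) // 2
        w = (image.shape[3] - size) // 2
        image = image[:, :, h:h + size, w:w + size]
    return image

x = torch.rand(1, 480, 640, 3)
print(preprocess_like_diff(x, crop=True).shape)   # torch.Size([1, 3, 224, 224])
print(preprocess_like_diff(x, crop=False).shape)  # torch.Size([1, 3, 224, 224])
```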

View File

@@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@@ -35,7 +35,7 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
@@ -78,6 +78,7 @@ class ControlBase:
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ class ControlBase:
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -181,7 +183,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__()
self.control_model = control_model
self.load_device = load_device
@@ -196,6 +198,7 @@ class ControlNet(ControlBase):
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -224,6 +227,7 @@ class ControlNet(ControlBase):
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@@ -448,6 +453,82 @@ def load_controlnet_mmdit(sd, model_options={}):
return control
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
@@ -560,7 +641,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux

View File

@@ -16,7 +16,7 @@ class NoiseScheduleVP:
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
"""Create a wrapper class for the forward SDE (VP type).
r"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
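
For context on why the r prefix above matters, here is a minimal illustration (not ComfyUI code): recent Python versions emit a SyntaxWarning for invalid escape sequences in ordinary string literals, docstrings included, while raw strings keep the backslashes literal.

```python
def plain_docstring():
    """LaTeX in a normal docstring: \(\sigma_t\)"""   # SyntaxWarning: invalid escape sequence '\('

def raw_docstring():
    r"""LaTeX in a raw docstring: \(\sigma_t\)"""     # no warning: backslashes stay literal
```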

View File

@@ -280,6 +280,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -306,6 +309,39 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
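
A short note on the re-noising step in sample_dpm_2_ancestral_RF above, assuming the rectified-flow convention used by this sampler where the signal coefficient is $\alpha_t = 1 - \sigma_t$:

$$
\sigma_{\text{down}} = \sigma_{i+1}\Bigl(1 + \eta\bigl(\tfrac{\sigma_{i+1}}{\sigma_i} - 1\bigr)\Bigr),
\qquad
\text{renoise\_coeff} = \sqrt{\sigma_{i+1}^{2} - \sigma_{\text{down}}^{2}\,\frac{\alpha_{i+1}^{2}}{\alpha_{\text{down}}^{2}}},
$$

so that after the final update $x \leftarrow \frac{\alpha_{i+1}}{\alpha_{\text{down}}}\,x + \text{renoise\_coeff}\cdot s_{\text{noise}}\cdot \epsilon$ the sample returns to the noise level $\sigma_{i+1}$ expected by the next step.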

View File

@@ -190,7 +190,21 @@ class Mochi(LatentFormat):
0.9294154431013696, 1.3720942357788521, 0.881393668867029,
0.9168315692124348, 0.9185249279345552, 0.9274757570805041]).view(1, self.latent_channels, 1, 1, 1)
self.latent_rgb_factors = None #TODO
self.latent_rgb_factors =[
[-0.0069, -0.0045, 0.0018],
[ 0.0154, -0.0692, -0.0274],
[ 0.0333, 0.0019, 0.0206],
[-0.1390, 0.0628, 0.1678],
[-0.0725, 0.0134, -0.1898],
[ 0.0074, -0.0270, -0.0209],
[-0.0176, -0.0277, -0.0221],
[ 0.5294, 0.5204, 0.3852],
[-0.0326, -0.0446, -0.0143],
[-0.0659, 0.0153, -0.0153],
[ 0.0185, -0.0217, 0.0014],
[-0.0396, -0.0495, -0.0281]
]
self.latent_rgb_factors_bias = [-0.0940, -0.1418, -0.1453]
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
@@ -202,3 +216,139 @@ class Mochi(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
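
A hedged sketch of how a latent_rgb_factors table like the one above is typically turned into a fast RGB preview (the actual ComfyUI preview code path is not shown in this diff): each latent channel contributes linearly to R, G and B, plus a per-color bias.

```python
import torch

def latent_to_rgb_preview(latent, rgb_factors, rgb_bias):
    # latent: [C, H, W] for a single frame; rgb_factors: C rows of [r, g, b]; rgb_bias: [r, g, b]
    factors = torch.tensor(rgb_factors, dtype=latent.dtype, device=latent.device)  # [C, 3]
    bias = torch.tensor(rgb_bias, dtype=latent.dtype, device=latent.device)        # [3]
    rgb = torch.einsum("chw,cr->rhw", latent, factors) + bias.view(3, 1, 1)
    return ((rgb.clamp(-1, 1) + 1.0) * 127.5).to(torch.uint8)  # rough [-1, 1] -> [0, 255]

# Example with the 128-channel LTXV factors defined above (one latent frame):
# preview = latent_to_rgb_preview(latent[0, :, 0], LTXV().latent_rgb_factors, LTXV().latent_rgb_factors_bias)
```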

View File

@@ -612,7 +612,9 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
info = {
"hidden_states": [],
@@ -643,9 +645,19 @@ class ContinuousTransformer(nn.Module):
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers
for layer in self.layers:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
for i, layer in enumerate(self.layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
@@ -874,7 +886,6 @@ class AudioDiffusionTransformer(nn.Module):
mask=None,
return_info=False,
control=None,
transformer_options={},
**kwargs):
return self._forward(
x,

View File

@@ -437,7 +437,8 @@ class MMDiT(nn.Module):
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
def forward(self, x, timestep, context, **kwargs):
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
# patchify x, add PE
b, c, h, w = x.shape
@@ -458,15 +459,36 @@ class MMDiT(nn.Module):
global_cond = self.t_embedder(t, x.dtype) # B, D
blocks_replace = patches_replace.get("dit", {})
if len(self.double_layers) > 0:
for layer in self.double_layers:
c, x = layer(c, x, global_cond, **kwargs)
for i, layer in enumerate(self.double_layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
cx = torch.cat([c, x], dim=1)
for layer in self.single_layers:
cx = layer(cx, global_cond, **kwargs)
for i, layer in enumerate(self.single_layers):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
x = cx[:, c_len:]
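The blocks_replace loops added in this and the neighbouring files all consume the same kind of entry: a callable keyed by ("double_block", i) or ("single_block", i) under transformer_options["patches_replace"]["dit"], called with the block inputs packed in a dict plus a second dict carrying the original block. A hedged sketch of such an entry (the block index and helper name are illustrative, not from this diff; the per-model arg keys such as "img", "txt", "vec", "pe" follow whatever each hunk passes in):

def my_double_block_patch(args, extra):
    # run the unpatched block, then post-process the streams it returns
    out = extra["original_block"](args)
    out["img"] = out["img"] * 1.0   # placeholder post-processing of the image tokens
    return out

transformer_options = {
    "patches_replace": {
        "dit": {("double_block", 3): my_double_block_patch},
    },
}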

View File

@@ -2,7 +2,7 @@ import torch
import comfy.ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
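A quick arithmetic check of the padding computed above: only the remainder needed to reach the next multiple of the patch size is added, so an already-aligned dimension gets zero padding.

patch_size = (2, 2)
h, w = 63, 64
pad_h = (patch_size[0] - h % patch_size[0]) % patch_size[0]   # 1
pad_w = (patch_size[1] - w % patch_size[1]) % patch_size[1]   # 0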

View File

@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
@@ -29,6 +30,7 @@ class FluxParams:
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
guidance_embed: bool
@@ -43,8 +45,9 @@ class Flux(nn.Module):
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
self.patch_size = params.patch_size
self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -96,7 +99,9 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -114,8 +119,19 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -127,7 +143,16 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -141,9 +166,9 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
@@ -151,10 +176,10 @@ class Flux(nn.Module):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
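A hedged sketch of the patchify step in the forward above, assuming the usual patch_size of 2 and an illustrative latent shape:

import torch
from einops import rearrange

patch_size = 2
x = torch.zeros(1, 16, 64, 64)   # [B, C, H, W] latent
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
# img: [1, 1024, 64]; h_len = w_len = 32, the same token grid the two
# img_ids linspace calls above index into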

comfy/ldm/flux/redux.py (new file, 25 lines)
View File

@@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x
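A hedged usage sketch of the new ReduxImageEncoder: vision-encoder embeddings of width redux_dim=1152 are expanded to three times the T5 width and projected back down to txt_in_features=4096, so they can be appended to the prompt conditioning. The token count below is illustrative.

import torch
from comfy.ldm.flux.redux import ReduxImageEncoder

enc = ReduxImageEncoder()
image_embeds = torch.zeros(1, 729, 1152)
cond = enc(image_embeds)   # [1, 729, 4096]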

View File

@@ -494,8 +494,9 @@ class AsymmDiTJoint(nn.Module):
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
control=None, **kwargs
control=None, transformer_options={}, **kwargs
):
patches_replace = transformer_options.get("patches_replace", {})
y_feat = context
y_mask = attention_mask
sigma = timestep
@@ -515,15 +516,32 @@ class AsymmDiTJoint(nn.Module):
)
del y_mask
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
args["img"],
args["vec"],
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)

View File

@@ -2,12 +2,16 @@
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -158,8 +162,10 @@ class ResBlock(nn.Module):
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
padding_mode: str = "replicate",
causal: bool = True,
prune_bottleneck: bool = False,
padding_mode: str,
bias: bool = True,
):
super().__init__()
self.channels = channels
@@ -170,23 +176,23 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
out_channels=channels,
out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
bias=bias,
causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
bias=True,
# causal=causal,
bias=bias,
causal=causal,
),
)
@@ -206,6 +212,81 @@ class ResBlock(nn.Module):
return self.attn_block(x)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.qk_norm = qk_norm
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
chunk_size: Chunk size for large tensors.
Returns:
x: Output tensor. Shape: [B, C, T, H, W].
"""
B, _, T, H, W = x.shape
if T == 1:
# No attention for single frame.
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
qkv = self.qkv(x)
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
x = self.out(x)
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
# 1D temporal attention.
x = rearrange(x, "B C t h w -> (B h w) t C")
qkv = self.qkv(x)
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
# Output: x with shape [B, num_heads, t, head_dim]
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
if self.qk_norm:
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
assert x.size(0) == q.size(0)
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
self.attn = Attention(dim, **attn_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
@@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
return x
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
assert has_attention is False #NOTE: if this is ever true add back the attention code.
attn_block = None #AttentionBlock(channels) if has_attention else None
return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs
)
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
class DownsampleBlock(nn.Module):
@@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
# First layer in each block always uses replicate padding
padding_mode="replicate",
bias=True,
bias=block_kwargs["bias"],
)
)
@@ -382,7 +459,7 @@ class Decoder(nn.Module):
blocks = []
first_block = [
nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
@@ -452,11 +529,165 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.fourier_features = FourierFeatures()
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.fourier_features(x)
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
self.encoder = None #TODO once the model releases
self.encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
self.decoder = Decoder(
out_channels=3,
base_channels=128,
@@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
)
def encode(self, x):
return self.encoder(x)
return self.encoder(x).mode()
def decode(self, x):
return self.decoder(x)
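A hedged usage sketch of the encoder path wired up above, following the shapes documented in Encoder.forward (video scaled to [-1, 1]); the concrete sizes are illustrative only:

import torch

vae = VideoVAE()
video = torch.zeros(1, 3, 25, 480, 848)   # [B, C, T, H, W]
latent = vae.encode(video)                # mode() of the LatentDistribution
# latent: [1, 12, 5, 60, 106], since t - 1 = (T - 1) // 6, h = H // 8, w = W // 8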

View File

@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
style=None,
return_dict=False,
control=None,
transformer_options=None,
transformer_options={},
):
"""
Forward pass of the encoder.
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
patches_replace = transformer_options.get("patches_replace", {})
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None
if control:
controls = control.get("output", None)
@@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)

View File

@@ -0,0 +1,514 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
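A hedged usage sketch of get_timestep_embedding: one sinusoidal vector per timestep, sines in the first half and cosines in the second (flip_sin_to_cos=False by default).

import torch

t = torch.tensor([0.0, 250.0, 999.0])
emb = get_timestep_embedding(t, embedding_dim=256)
# emb: [3, 256]; emb[:, :128] are the sin terms, emb[:, 128:] the cos terms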
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
else:
self.cond_proj = None
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
if post_act_fn is None:
self.post_act = None
# else:
# self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
def forward(self, x):
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
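Restating apply_rotary_emb above: because cos_freqs and sin_freqs are repeat_interleave'd by 2 in precompute_freqs_cis, each adjacent feature pair (x1, x2) shares one angle theta and is rotated by it.

# out1 = x1 * cos(theta) - x2 * sin(theta)
# out2 = x1 * sin(theta) + x2 * cos(theta)
# i.e. a plain 2D rotation per pair, applied to q and k before attention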
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
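Spelled out, the adaLN-single gating in BasicTransformerBlock.forward above is the following, with the six shift/scale/gate vectors coming from scale_shift_table plus the projected timestep:

# x = x + gate_msa * SelfAttn(rms_norm(x) * (1 + scale_msa) + shift_msa)
# x = x + CrossAttn(x, context)                     # cross-attention is not modulated
# x = x + gate_mlp * FF(rms_norm(x) * (1 + scale_mlp) + shift_mlp)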
def get_fractional_positions(indices_grid, max_pos):
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
fractional_positions = get_fractional_positions(indices_grid, max_pos)
start = 1
end = theta
device = fractional_positions.device
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for d in range(num_layers)
]
)
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = 0.0
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
class Patchifier(ABC):
def __init__(self, patch_size: int):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
@abstractmethod
def patchify(
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
) -> Tuple[Tensor, Tensor]:
pass
@abstractmethod
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
pass
@property
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier):
def patchify(
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
output_height = output_height // self._patch_size[1]
output_width = output_width // self._patch_size[2]
latents = rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q) ",
f=output_num_frames,
h=output_height,
w=output_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
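A hedged usage sketch: LTXVModel constructs SymmetricPatchifier(1), so the patch size is (1, 1, 1) and patchify simply flattens the latent grid into a token sequence, with unpatchify inverting it. Shapes below are illustrative.

import torch

p = SymmetricPatchifier(1)
latents = torch.zeros(2, 128, 5, 16, 24)   # [B, C, F, H, W]
tokens = p.patchify(latents)               # [2, 5*16*24, 128] = [2, 1920, 128]
restored = p.unpatchify(tokens, output_height=16, output_width=24,
                        output_num_frames=5, out_channels=128)
# restored: [2, 128, 5, 16, 24]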

View File

@@ -0,0 +1,64 @@
from typing import Tuple, Union
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
dilation = (dilation, 1, 1)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = ops.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
groups=groups,
)
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
@property
def weight(self):
return self.conv.weight
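A hedged shape sketch of the causal padding above: with the default 3x3x3 kernel, the first frame is repeated time_kernel_size - 1 = 2 times in front, so the temporal length is preserved and no output frame depends on future frames.

import torch

conv = CausalConv3d(in_channels=8, out_channels=8, kernel_size=3)
x = torch.zeros(1, 8, 10, 32, 32)   # [B, C, T, H, W]
y = conv(x, causal=True)            # [1, 8, 10, 32, 32]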

View File

@@ -0,0 +1,698 @@
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Any, Mapping, Optional, Tuple, Union, List
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
):
super().__init__()
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.blocks_desc = blocks
in_channels = in_channels * patch_size**2
output_channel = base_channels
self.conv_in = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.down_blocks = nn.ModuleList([])
for block_name, block_params in blocks:
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 1, 1),
causal=True,
)
elif block_name == "compress_space":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(1, 2, 2),
causal=True,
)
elif block_name == "compress_all":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
else:
raise ValueError(f"unknown block: {block_name}")
self.down_blocks.append(block)
# out
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = out_channels
if latent_log_var == "per_channel":
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
num_dims = sample.dim()
if num_dims == 4:
# For shape (B, C, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
elif num_dims == 5:
# For shape (B, C, F, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
# assert target_shape is not None, "target_shape must be provided"
sample = self.conv_in(sample, causal=self.causal)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
sample = sample.to(upscale_dtype)
for up_block in self.up_blocks:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
class UNetMidBlock3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
Args:
in_channels (`int`): The number of input channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: str = "group_norm",
):
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.res_blocks = nn.ModuleList(
[
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
norm_layer=norm_layer,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True
) -> torch.FloatTensor:
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states, causal=causal)
return hidden_states
class DepthToSpaceUpsample(nn.Module):
def __init__(self, dims, in_channels, stride, residual=False):
super().__init__()
self.stride = stride
self.out_channels = math.prod(stride) * in_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
causal=True,
)
self.residual = residual
def forward(self, x, causal: bool = True):
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
x = self.norm(x)
x = rearrange(x, "b d h w c -> b c d h w")
return x
class ResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
norm_layer: str = "group_norm",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
if norm_layer == "group_norm":
self.norm1 = nn.GroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm1 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv_nd(
dims,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if norm_layer == "group_norm":
self.norm2 = nn.GroupNorm(
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm2 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv_nd(
dims,
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.conv_shortcut = (
make_linear_nd(
dims=dims, in_channels=in_channels, out_channels=out_channels
)
if in_channels != out_channels
else nn.Identity()
)
self.norm3 = (
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
if in_channels != out_channels
else nn.Identity()
)
def forward(
self,
input_tensor: torch.FloatTensor,
causal: bool = True,
) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
hidden_states = self.norm2(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
input_tensor = self.norm3(input_tensor)
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
return x
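A hedged sketch of patchify/unpatchify above with the patch_size=4 used by the VideoVAE config below: spatial pixels are folded into channels before the first conv (16x the channels, H and W divided by 4), while patch_size_t stays 1.

import torch

x = torch.zeros(1, 3, 9, 512, 768)                    # [B, C, F, H, W]
y = patchify(x, patch_size_hw=4, patch_size_t=1)      # [1, 48, 9, 128, 192]
x2 = unpatchify(y, patch_size_hw=4, patch_size_t=1)   # back to [1, 3, 9, 512, 768]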
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("blocks")),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("blocks")),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
)
self.per_channel_statistics = processor()
def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x):
return self.decoder(self.per_channel_statistics.un_normalize(x))

View File

@@ -0,0 +1,83 @@
from typing import Tuple, Union
import torch
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd(
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: int,
kernel_size: int,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
causal=False,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == 3:
if causal:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
return ops.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == (2, 1):
return DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
dims: int,
in_channels: int,
out_channels: int,
bias=True,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
elif dims == 3 or dims == (2, 1):
return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
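
make_linear_nd builds a per-position linear layer out of a 1x1 convolution, which is numerically equivalent to nn.Linear applied channel-wise. A quick equivalence check, using plain torch.nn in place of the comfy.ops wrappers (an assumption for this sketch):

import torch
import torch.nn as nn

linear = nn.Linear(8, 16)
conv1x1 = nn.Conv2d(8, 16, kernel_size=1)

# give both layers the same parameters
with torch.no_grad():
    conv1x1.weight.copy_(linear.weight.view(16, 8, 1, 1))
    conv1x1.bias.copy_(linear.bias)

x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)
out_conv = conv1x1(x)
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_linear, atol=1e-6))  # True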

View File

@@ -0,0 +1,195 @@
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if kernel_size == (1, 1, 1):
raise ValueError(
"kernel_size must be greater than 1. Use make_linear_nd instead."
)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
# Set parameters for convolutions
self.groups = groups
self.bias = bias
# Define the size of the channels after the first convolution
intermediate_channels = (
out_channels if in_channels < out_channels else in_channels
)
# Define parameters for the first convolution
self.weight1 = nn.Parameter(
torch.Tensor(
intermediate_channels,
in_channels // groups,
1,
kernel_size[1],
kernel_size[2],
)
)
self.stride1 = (1, stride[1], stride[2])
self.padding1 = (0, padding[1], padding[2])
self.dilation1 = (1, dilation[1], dilation[2])
if bias:
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
else:
self.register_parameter("bias1", None)
# Define parameters for the second convolution
self.weight2 = nn.Parameter(
torch.Tensor(
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
)
)
self.stride2 = (stride[0], 1, 1)
self.padding2 = (padding[0], 0, 0)
self.dilation2 = (dilation[0], 1, 1)
if bias:
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias2", None)
# Initialize weights and biases
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
if self.bias:
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
bound1 = 1 / math.sqrt(fan_in1)
nn.init.uniform_(self.bias1, -bound1, bound1)
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
bound2 = 1 / math.sqrt(fan_in2)
nn.init.uniform_(self.bias2, -bound2, bound2)
def forward(self, x, use_conv3d=False, skip_time_conv=False):
if use_conv3d:
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
else:
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
def forward_with_3d(self, x, skip_time_conv):
# First convolution
x = F.conv3d(
x,
self.weight1,
self.bias1,
self.stride1,
self.padding1,
self.dilation1,
self.groups,
)
if skip_time_conv:
return x
# Second convolution
x = F.conv3d(
x,
self.weight2,
self.bias2,
self.stride2,
self.padding2,
self.dilation2,
self.groups,
)
return x
def forward_with_2d(self, x, skip_time_conv):
b, c, d, h, w = x.shape
# First 2D convolution
x = rearrange(x, "b c d h w -> (b d) c h w")
# Squeeze the depth dimension out of weight1 since it's 1
weight1 = self.weight1.squeeze(2)
# Select stride, padding, and dilation for the 2D convolution
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
_, _, h, w = x.shape
if skip_time_conv:
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
return x
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
# Reshape weight2 to match the expected dimensions for conv1d
weight2 = self.weight2.squeeze(-1).squeeze(-1)
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
@property
def weight(self):
return self.weight2
def test_dual_conv3d_consistency():
# Initialize parameters
in_channels = 3
out_channels = 5
kernel_size = (3, 3, 3)
stride = (2, 2, 2)
padding = (1, 1, 1)
# Create an instance of the DualConv3d class
dual_conv3d = DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=True,
)
# Example input tensor
test_input = torch.randn(1, 3, 10, 10, 10)
# Perform forward passes with both 3D and 2D settings
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
output_2d = dual_conv3d(test_input, use_conv3d=False)
# Assert that the outputs from both methods are sufficiently close
assert torch.allclose(
output_conv3d, output_2d, atol=1e-6
), "Outputs are not consistent between 3D and 2D convolutions."

View File

@@ -0,0 +1,12 @@
import torch
from torch import nn
class PixelNorm(nn.Module):
def __init__(self, dim=1, eps=1e-8):
super(PixelNorm, self).__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
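
PixelNorm divides each position by the root-mean-square of its values along the chosen dim (channels by default), so every spatial location ends up with roughly unit RMS across channels. A short sanity check with plain torch:

import torch

x = torch.randn(2, 128, 4, 8, 8)  # (batch, channels, frames, height, width)
eps = 1e-8
y = x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + eps)

# after normalization every position has ~unit RMS over the channel dimension
rms = torch.sqrt(torch.mean(y ** 2, dim=1))
print(rms.mean().item())  # ~1.0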

View File

@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
if mask.shape[1] == 1:
s1 += mask
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
@@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
pad = 8 - mask.shape[-1] % 8
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
mask = mask_out[..., :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
@@ -393,6 +396,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
@@ -404,10 +414,15 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if SDP_BATCH_LIMIT >= q.shape[0]:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
if mask.shape[1] == 1:
return mask
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]

View File

@@ -49,10 +49,20 @@ def load_lora(lora, to_load):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
@@ -72,6 +82,10 @@ def load_lora(lora, to_load):
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
@@ -82,7 +96,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@@ -193,6 +207,12 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
set_weight_name = "{}.set_weight".format(x)
set_weight = lora.get(set_weight_name, None)
if set_weight is not None:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
@@ -282,11 +302,14 @@ def model_lora_keys_unet(model, key_map={}):
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
@@ -344,6 +367,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map
@@ -440,10 +469,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:

17
comfy/lora_convert.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
return sd
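
convert_lora_bfl_control renames BFL control-lora keys into the patcher's conventions (.diff_b for bias diffs, .set_weight for norm scales) and records the img_in reshape target so calculate_weight can pad that weight later. A toy run of the same renaming with a made-up minimal state dict (entries and shapes here are illustrative only):

import torch

# hypothetical BFL-style control lora entries, shapes chosen only for illustration
sd = {
    "img_in.lora_A.weight": torch.randn(16, 64),
    "img_in.lora_B.weight": torch.randn(3072, 16),
    "img_in.lora_B.bias": torch.randn(3072),
    "single_blocks.0.norm.key_norm.scale": torch.randn(128),
}

sd_out = {}
for k in sd:
    k_to = "diffusion_model.{}".format(
        k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")
    )
    sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor(
    [sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]]
)
# keys now use .diff_b / .set_weight and carry the [3072, 64] reshape target for img_in
print(sorted(sd_out.keys()))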

View File

@@ -30,6 +30,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.model_management
import comfy.conds
@@ -153,8 +154,7 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
return None
def extra_conds(self, **kwargs):
out = {}
def concat_cond(self, **kwargs):
if len(self.concat_keys) > 0:
cond_concat = []
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
@@ -193,7 +193,14 @@ class BaseModel(torch.nn.Module):
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
return data
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
if concat_cond is not None:
out['c_concat'] = comfy.conds.CONDNoiseShape(concat_cond)
adm = self.encode_adm(**kwargs)
if adm is not None:
@@ -523,9 +530,7 @@ class SD_X4Upscaler(BaseModel):
return out
class IP2P:
def extra_conds(self, **kwargs):
out = {}
def concat_cond(self, **kwargs):
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
@@ -537,18 +542,15 @@ class IP2P:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
return out
class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
@@ -709,6 +711,44 @@ class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def concat_cond(self, **kwargs):
try:
#Handle Flux control loras dynamically changing the img_in weight.
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
except:
#Some cases like tensorrt might not have the weights accessible
num_channels = self.model_config.unet_config["in_channels"]
out_channels = self.model_config.unet_config["out_channels"]
if num_channels <= out_channels:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
image = self.process_latent_in(image)
if num_channels <= out_channels * 2:
return image
#inpaint model
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
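
For the inpaint variant, the mask is upscaled to pixel resolution and then folded back to latent resolution with an 8x8 space-to-channel reshape, giving 64 mask channels that can be concatenated with the image latent. The reshape in isolation, on an arbitrary latent size:

import torch

latent_h, latent_w = 32, 32
mask = torch.ones(1, 1, latent_h * 8, latent_w * 8)  # mask at pixel resolution

packed = (
    mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8)
    .permute(0, 2, 4, 1, 3)
    .reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
)
print(packed.shape)  # torch.Size([1, 64, 32, 32]): each 8x8 pixel block becomes 64 channels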
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
@@ -734,3 +774,23 @@ class GenmoMochi(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
return out

View File

@@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
patch_size = 2
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
@@ -177,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -321,8 +331,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
if scaled_fp8_weight is not None:
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
@@ -540,7 +551,11 @@ def model_config_from_diffusers_unet(state_dict):
def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {}
if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
@@ -549,10 +564,6 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else:
return None

View File

@@ -628,6 +628,10 @@ def maximum_vram_for_weights(device=None):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
return torch.float32
if args.fp64_unet:
return torch.float64
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -674,7 +678,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
@@ -896,7 +900,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
upcast = True
except:
pass

View File

@@ -367,20 +367,33 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
loading = []
for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
loading = self._load_list()
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
params = x[3]
module_mem = x[0]
lowvram_weight = False
@@ -416,22 +429,22 @@ class ModelPatcher:
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m))
load_completely.append((module_mem, n, m, params))
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
params = x[3]
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -505,14 +518,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
unload_list = []
for n, m in self.model.named_modules():
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
unload_list.append((module_mem, n, m))
unload_list = self._load_list()
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
@@ -520,32 +526,42 @@ class ModelPatcher:
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
params = unload[3]
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
move_weight = True
for param in params:
key = "{}.{}".format(n, param)
bk = self.backup.get(key, None)
if bk is not None:
if not lowvram_possible:
move_weight = False
break
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
m.to(device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter

View File

@@ -2,6 +2,25 @@ import torch
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
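
rescale_zero_terminal_snr_sigmas shifts and rescales the alpha-bar schedule so the final timestep carries (almost) no signal: the first sigma is preserved while the last one becomes very large. A small numerical check on a toy schedule (the helper is copied from above so the snippet runs standalone):

import torch

def rescale_zero_terminal_snr_sigmas(sigmas):
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= alphas_bar_sqrt_T                                            # shift: last step -> zero signal
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)  # keep the first value
    alphas_bar = alphas_bar_sqrt ** 2
    alphas_bar[-1] = 4.8973451890853435e-08                                         # avoid an infinite last sigma
    return ((1 - alphas_bar) / alphas_bar) ** 0.5

sigmas = torch.linspace(0.03, 14.6, 1000)            # toy increasing sigma schedule
rescaled = rescale_zero_terminal_snr_sigmas(sigmas)
print(sigmas[0].item(), rescaled[0].item())          # first sigma is unchanged
print(sigmas[-1].item(), rescaled[-1].item())        # last sigma becomes ~4.5e3 (near-zero SNR)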
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
@@ -48,7 +67,7 @@ class CONST:
return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
def __init__(self, model_config=None, zsnr=None):
super().__init__()
if model_config is not None:
@@ -61,11 +80,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
linear_end = sampling_settings.get("linear_end", 0.012)
timesteps = sampling_settings.get("timesteps", 1000)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
if zsnr is None:
zsnr = sampling_settings.get("zsnr", False)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
if given_betas is not None:
betas = given_betas
else:
@@ -83,6 +105,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
if zsnr:
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):

View File

@@ -1,14 +1,10 @@
import torch
import comfy.model_management
import comfy.conds
import comfy.utils
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
def get_models_from_cond(cond, model_type):
models = []

View File

@@ -8,6 +8,7 @@ from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import yaml
import comfy.utils
@@ -27,12 +28,16 @@ import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.model_patcher
import comfy.lora
import comfy.lora_convert
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = {}
if model is not None:
@@ -40,6 +45,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
@@ -171,6 +177,7 @@ class VAE:
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
self.latent_dim = 2
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -240,16 +247,30 @@ class VAE:
self.output_channels = 2
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
self.first_stage_model = comfy.ldm.genmo.vae.model.VideoVAE()
self.latent_channels = 12
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initialized.")
self.first_stage_model = None
@@ -349,29 +370,52 @@ class VAE:
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
pixel_samples = self.decode_tiled_3d(samples_in)
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1)
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -392,6 +436,12 @@ class VAE:
def get_sd(self):
return self.first_stage_model.state_dict()
def spacial_compression_decode(self):
try:
return self.upscale_ratio[-1]
except:
return self.upscale_ratio
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@@ -405,6 +455,8 @@ def load_style_model(ckpt_path):
keys = model_data.keys()
if "style_embedding" in keys:
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
elif "redux_down.weight" in keys:
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
@@ -418,6 +470,7 @@ class CLIPType(Enum):
HUNYUAN_DIT = 5
FLUX = 6
MOCHI = 7
LTXV = 8
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@@ -496,6 +549,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer

View File

@@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
from . import supported_models_base
from . import latent_formats
@@ -197,6 +198,8 @@ class SDXL(supported_models_base.BASE):
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
return model_base.ModelType.V_PREDICTION_EDM
elif "v_pred" in state_dict:
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
self.sampling_settings["zsnr"] = True
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS
@@ -656,6 +659,15 @@ class Flux(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class FluxInpaint(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
"in_channels": 96,
}
supported_inference_dtypes = [torch.bfloat16, torch.float32]
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
@@ -700,7 +712,34 @@ class GenmoMochi(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
class LTXV(supported_models_base.BASE):
unet_config = {
"image_model": "ltxv",
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
sampling_settings = {
"shift": 2.37,
}
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 2.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
models += [SVD_img2vid]

View File

@@ -12,7 +12,7 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):

18
comfy/text_encoders/lt.py Normal file
View File

@@ -0,0 +1,18 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)

View File

@@ -316,10 +316,18 @@ MMDIT_MAP_BLOCK = {
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
@@ -349,6 +357,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
k = "{}.attn2.".format(block_from)
qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
for k in MMDIT_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
@@ -834,3 +848,24 @@ class ProgressBar:
def update(self, value):
self.update_absolute(self.current + value)
def reshape_mask(input_mask, output_shape):
dims = len(output_shape) - 2
if dims == 1:
scale_mode = "linear"
if dims == 2:
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "bilinear"
if dims == 3:
if len(input_mask.shape) < 5:
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
scale_mode = "trilinear"
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
return mask
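
reshape_mask interpolates an arbitrary mask to the latent's spatial size and repeats it across channels and batch. A usage-shaped sketch of the 2D path with plain torch, where a simple repeat stands in for comfy.utils.repeat_to_batch_size (an assumption for this sketch):

import torch

mask = torch.rand(1, 512, 512)              # single-channel image mask
output_shape = (4, 16, 64, 64)              # target latent: (batch, channels, height, width)

m = mask.reshape(-1, 1, mask.shape[-2], mask.shape[-1])
m = torch.nn.functional.interpolate(m, size=output_shape[2:], mode="bilinear")
m = m.repeat(1, output_shape[1], 1, 1)      # broadcast the mask across latent channels
m = m.repeat(output_shape[0] // m.shape[0], 1, 1, 1)  # naive stand-in for repeat_to_batch_size
print(m.shape)                              # torch.Size([4, 16, 64, 64])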

181
comfy_extras/nodes_lt.py Normal file
View File

@@ -0,0 +1,181 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import math
class EmptyLTXVLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video/ltxv"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
class LTXVImgToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"image": ("IMAGE",),
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
return (positive, negative, {"samples": latent}, )
class LTXVConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
CATEGORY = "conditioning/video_models"
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative)
class ModelSamplingLTXV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
},
"optional": {"latent": ("LATENT",), }
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
m = model.clone()
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (tokens) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class LTXVScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
"stretch": ("BOOLEAN", {
"default": True,
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
}),
"terminal": (
"FLOAT",
{
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
"tooltip": "The terminal value of the sigmas after stretching."
},
),
},
"optional": {"latent": ("LATENT",), }
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
sigmas = torch.linspace(1.0, 0.0, steps + 1)
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = (tokens) * mm + b
power = 1
sigmas = torch.where(
sigmas != 0,
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
0,
)
# Stretch sigmas so that the final value matches the given terminal value.
if stretch:
non_zero_mask = sigmas != 0
non_zero_sigmas = sigmas[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched
return (sigmas,)
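
The scheduler builds a linear sigma ramp, applies the resolution-dependent shift sigma' = exp(m) / (exp(m) + (1/sigma - 1)) with m interpolated from the latent token count, and optionally stretches the nonzero sigmas so the smallest one lands on the terminal value. A standalone sketch of that shift-and-stretch step (constants copied from the node above, token count chosen as an example):

import math
import torch

steps, max_shift, base_shift, terminal = 20, 2.05, 0.95, 0.1
tokens = 13 * 16 * 24                        # e.g. a 97-frame 512x768 latent: frames * h * w

sigmas = torch.linspace(1.0, 0.0, steps + 1)
mm = (max_shift - base_shift) / (4096 - 1024)
shift = tokens * mm + (base_shift - mm * 1024)               # ~2.37 for this latent size

sigmas = torch.where(
    sigmas != 0,
    math.exp(shift) / (math.exp(shift) + (1 / sigmas - 1)),
    torch.zeros_like(sigmas),
)

# stretch the nonzero sigmas so the smallest one equals the terminal value
nz = sigmas != 0
one_minus = 1.0 - sigmas[nz]
sigmas[nz] = 1.0 - one_minus / (one_minus[-1] / (1.0 - terminal))
print(sigmas[0].item(), sigmas[-2].item(), sigmas[-1].item())  # 1.0, ~0.1, 0.0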
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
}

View File

@@ -3,9 +3,6 @@ import torch
import comfy.model_management
class EmptyMochiLatentVideo:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
@@ -15,10 +12,10 @@ class EmptyMochiLatentVideo:
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/mochi"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=self.device)
latent = torch.zeros([batch_size, 12, ((length - 1) // 6) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {

View File

@@ -26,8 +26,8 @@ class X0(comfy.model_sampling.EPS):
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
def __init__(self, model_config=None):
super().__init__(model_config)
def __init__(self, model_config=None, zsnr=None):
super().__init__(model_config, zsnr=zsnr)
self.skip_steps = self.num_timesteps // self.original_timesteps
@@ -51,25 +51,6 @@ class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete)
return log_sigma.exp().to(timestep.device)
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
@@ -100,9 +81,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
m.add_object_patch("model_sampling", model_sampling)
return (m, )

View File

@@ -75,6 +75,34 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeAuraflow(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["init_x_linear."] = argument
arg_dict["positional_encoding"] = argument
arg_dict["cond_seq_linear."] = argument
arg_dict["register_tokens"] = argument
arg_dict["t_embedder."] = argument
for i in range(4):
arg_dict["double_layers.{}.".format(i)] = argument
for i in range(32):
arg_dict["single_layers.{}.".format(i)] = argument
arg_dict["modF."] = argument
arg_dict["final_linear."] = argument
return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@@ -124,11 +152,58 @@ class ModelMergeSD35_Large(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeMochiPreview(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["pos_frequencies."] = argument
arg_dict["t_embedder."] = argument
arg_dict["t5_y_embedder."] = argument
arg_dict["t5_yproj."] = argument
for i in range(48):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patchify_proj."] = argument
arg_dict["adaln_single."] = argument
arg_dict["caption_projection."] = argument
for i in range(28):
arg_dict["transformer_blocks.{}.".format(i)] = argument
arg_dict["scale_shift_table"] = argument
arg_dict["proj_out."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeAuraflow": ModelMergeAuraflow,
"ModelMergeFlux1": ModelMergeFlux1,
"ModelMergeSD35_Large": ModelMergeSD35_Large,
"ModelMergeMochiPreview": ModelMergeMochiPreview,
"ModelMergeLTXV": ModelMergeLTXV,
}

View File

@@ -57,12 +57,24 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
total = mask.shape[-1]
x = round(math.sqrt((lh / lw) * total))
xx = None
for i in range(0, math.floor(math.sqrt(total) / 2)):
for j in [(x + i), max(1, x - i)]:
if total % j == 0:
xx = j
break
if xx is not None:
break
x = xx
y = total // x
# Reshape
mask = (
mask.reshape(b, *mid_shape)
mask.reshape(b, x, y)
.unsqueeze(1)
.type(attn.dtype)
)
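
The replacement logic searches outward from round(sqrt((lh/lw) * total)) for a divisor of the token count, so non-square latents still get an exact (x, y) factorization for the reshape instead of the previous power-of-two estimate. The search extracted into a standalone helper (factor_hw is an illustrative name):

import math

def factor_hw(total, lh, lw):
    # find (x, y) with x * y == total and x / y close to lh / lw
    x = round(math.sqrt((lh / lw) * total))
    for i in range(0, math.floor(math.sqrt(total) / 2)):
        for j in [(x + i), max(1, x - i)]:
            if total % j == 0:
                return j, total // j
    return 1, total

print(factor_hw(1536, 48, 32))   # (48, 32): square-ish latent factors on the first try
print(factor_hw(1320, 33, 40))   # (33, 40): an uneven token count still factors exactly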

View File

@@ -3,24 +3,29 @@ import comfy.sd
import comfy.model_management
import nodes
import torch
import re
import comfy_extras.nodes_slg
class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), )
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("clip", clip_name3)
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)
class EmptySD3LatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@@ -39,6 +44,7 @@ class EmptySD3LatentImage:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3:
@classmethod
def INPUT_TYPES(s):
@@ -95,7 +101,8 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
CATEGORY = "conditioning/controlnet"
DEPRECATED = True
class SkipLayerGuidanceSD3:
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
'''
Enhance guidance towards detailed structure by having another set of CFG negatives with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
@@ -110,47 +117,12 @@ class SkipLayerGuidanceSD3:
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
FUNCTION = "skip_guidance_sd3"
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, layers, scale, start_percent, end_percent):
if layers == "" or layers == None:
return (model, )
# check if layer is comma separated integers
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()
for layer in layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
model_sampling.percent_to_sigma(start_percent)
sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result
layers = re.findall(r'\d+', layers)
layers = [int(i) for i in layers]
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, )
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
NODE_CLASS_MAPPINGS = {

comfy_extras/nodes_slg.py Normal file

@@ -0,0 +1,78 @@
import comfy.model_patcher
import comfy.samplers
import re
class SkipLayerGuidanceDiT:
'''
Enhance guidance towards detailed structure by having another set of CFG negatives with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
# check if layer is comma separated integers
def skip(args, extra_args):
return args
model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)
double_layers = re.findall(r'\d+', double_layers)
double_layers = [int(i) for i in double_layers]
single_layers = re.findall(r'\d+', single_layers)
single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()
for layer in double_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
for layer in single_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)
model_sampling.percent_to_sigma(start_percent)
sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result
m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
}
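Two small pieces of logic carry the new node: the layer strings are parsed with a digit regex, and the post-CFG hook nudges the result away from the skipped-layer prediction. A rough illustration, assuming cond_pred, slg_pred and cfg_result are same-shaped tensors:

import re
import torch

def parse_layers(layer_string):
    # "7, 8, 9" -> [7, 8, 9]; any non-digit separators are tolerated.
    return [int(i) for i in re.findall(r'\d+', layer_string)]

def apply_slg(cfg_result, cond_pred, slg_pred, scale):
    # Push the output towards the full conditional prediction and away from
    # the prediction computed with the listed blocks skipped.
    return cfg_result + (cond_pred - slg_pred) * scale

print(parse_layers("7, 8, 9"))              # [7, 8, 9]
x = torch.zeros(1, 4)
print(apply_slg(x, x + 1.0, x + 0.5, 3.0))  # tensor of 1.5s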

Binary file not shown (image, 116 KiB before).


@@ -18,7 +18,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["text_encoders"] = ([os.path.join(models_dir, "text_encoders"), os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["diffusion_models"] = ([os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
@@ -81,7 +81,8 @@ extension_mimetypes_cache = {
}
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
legacy = {"unet": "diffusion_models",
"clip": "text_encoders"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory):
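The widened legacy table means requests for the old folder names keep resolving to the new directories; a minimal sketch of that lookup, using only the dict shown above:

def map_legacy(folder_name: str) -> str:
    legacy = {"unet": "diffusion_models",
              "clip": "text_encoders"}
    return legacy.get(folder_name, folder_name)

print(map_legacy("clip"))         # text_encoders
print(map_legacy("unet"))         # diffusion_models
print(map_legacy("checkpoints"))  # checkpoints (unchanged)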


@@ -47,7 +47,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
latent_image = torch.nn.functional.linear(x0[0].permute(1, 2, 0), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
if x0.ndim == 5:
x0 = x0[0, :, 0]
else:
x0 = x0[0]
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image)
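The previewer change handles video latents, which carry an extra frame dimension. A sketch of the shape handling with dummy tensors (the zero factors are placeholders, not real latent-to-RGB weights):

import torch

def preview_projection(x0, rgb_factors, bias=None):
    # Video latents are [batch, channels, frames, height, width]: preview the
    # first frame of the first batch element. Image latents stay [b, c, h, w].
    if x0.ndim == 5:
        x0 = x0[0, :, 0]
    else:
        x0 = x0[0]
    # Channels last so the linear layer maps latent channels -> RGB.
    return torch.nn.functional.linear(x0.movedim(0, -1), rgb_factors, bias=bias)

factors = torch.zeros(3, 16)  # placeholder latent->RGB matrix
print(preview_projection(torch.randn(1, 16, 8, 32, 32), factors).shape)  # (32, 32, 3)
print(preview_projection(torch.randn(1, 16, 32, 32), factors).shape)     # (32, 32, 3)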


@@ -71,6 +71,7 @@ if os.name == "nt":
if __name__ == "__main__":
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.deterministic:
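The added line mirrors the CUDA behaviour for ROCm: CUDA builds of PyTorch honour CUDA_VISIBLE_DEVICES while ROCm builds honour HIP_VISIBLE_DEVICES, so setting both covers either backend. A minimal sketch with a hypothetical device index:

import os

cuda_device = 1  # hypothetical --cuda-device value
os.environ['CUDA_VISIBLE_DEVICES'] = str(cuda_device)  # read by CUDA builds
os.environ['HIP_VISIBLE_DEVICES'] = str(cuda_device)   # read by ROCm builds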


@@ -1,2 +0,0 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename


@@ -1,234 +0,0 @@
#NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import aiohttp
import os
import traceback
import logging
from folder_paths import folder_names_and_paths, get_folder_paths
import re
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from dataclasses import dataclass
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
folder_path (str):
Path to which model folder should be used as the root.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)
if not model_directory in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if not folder_path in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file
try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(model_name, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback)
def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(model_name, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(model_name, status)
last_update_time = time.time()
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
os.rename(temp_file_path, file_path)
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(model_name, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return status
def validate_filename(filename: str)-> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False
# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False
# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False
# Ensure the filename isn't too long
if len(filename) > 255:
return False
return True
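For reference, the behaviour of validate_filename above, assuming the function is in scope; the expected outputs match the parametrized test further down in this diff:

print(validate_filename("valid_model.safetensors"))      # True
print(validate_filename("valid model.safetensors"))      # True  (spaces are allowed)
print(validate_filename("../../../etc/passwd"))          # False (path traversal)
print(validate_filename(".hidden_file.pt"))              # False (hidden file)
print(validate_filename("very" * 100 + ".safetensors"))  # False (name too long)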


@@ -290,15 +290,22 @@ class VAEDecodeTiled:
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
"tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
def decode(self, vae, samples, tile_size):
return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
def decode(self, vae, samples, tile_size, overlap=64):
if tile_size < overlap * 4:
overlap = tile_size // 4
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
class VAEEncode:
@classmethod
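A sketch of the new tiled-decode argument handling, tracing pixel-space tile sizes into latent-space ones (the numbers are only worked examples, not from the diff):

def tiled_decode_args(tile_size, overlap, compression):
    # Cap the overlap at a quarter of the tile, then convert both values from
    # pixel units to latent units using the VAE's spatial compression factor.
    if tile_size < overlap * 4:
        overlap = tile_size // 4
    return tile_size // compression, tile_size // compression, overlap // compression

print(tiled_decode_args(512, 64, 8))   # (64, 64, 8)  typical image VAE
print(tiled_decode_args(128, 64, 8))   # (16, 16, 4)  overlap clamped to 32 first
print(tiled_decode_args(512, 64, 32))  # (16, 16, 2)  a VAE with heavier spatial compression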
@@ -376,6 +383,7 @@ class InpaintModelConditioning:
"vae": ("VAE", ),
"pixels": ("IMAGE", ),
"mask": ("MASK", ),
"noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
}}
RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
@@ -384,7 +392,7 @@ class InpaintModelConditioning:
CATEGORY = "conditioning/inpaint"
def encode(self, positive, negative, pixels, vae, mask):
def encode(self, positive, negative, pixels, vae, mask, noise_mask=True):
x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@@ -408,7 +416,8 @@ class InpaintModelConditioning:
out_latent = {}
out_latent["samples"] = orig_latent
out_latent["noise_mask"] = mask
if noise_mask:
out_latent["noise_mask"] = mask
out = []
for conditioning in [positive, negative]:
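The new boolean only controls whether the mask is attached to the latent dict; a minimal sketch with dummy tensors (shapes are illustrative):

import torch

def build_out_latent(orig_latent, mask, noise_mask=True):
    out_latent = {"samples": orig_latent}
    if noise_mask:
        # With the mask attached, sampling is restricted to the masked region.
        out_latent["noise_mask"] = mask
    return out_latent

latent = torch.zeros(1, 4, 64, 64)
mask = torch.ones(1, 1, 512, 512)
print("noise_mask" in build_out_latent(latent, mask, True))   # True
print("noise_mask" in build_out_latent(latent, mask, False))  # False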
@@ -888,14 +897,16 @@ class UNETLoader:
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5"
def load_clip(self, clip_name, type="stable_diffusion"):
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
@@ -905,18 +916,20 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_path = folder_paths.get_full_path_or_raise("clip", clip_name)
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class DualCLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
"clip_name2": (folder_paths.get_filename_list("clip"), ),
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux"], ),
}}
RETURN_TYPES = ("CLIP",)
@@ -924,9 +937,11 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5"
def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path_or_raise("clip", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("clip", clip_name2)
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
@@ -956,15 +971,19 @@ class CLIPVisionEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",)
"image": ("IMAGE",),
"crop": (["center", "none"],)
}}
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip_vision, image):
output = clip_vision.encode_image(image)
def encode(self, clip_vision, image, crop):
crop_image = True
if crop != "center":
crop_image = False
output = clip_vision.encode_image(image, crop=crop_image)
return (output,)
class StyleModelLoader:
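The new crop dropdown maps onto a boolean passed to the vision encoder; the mapping is simply:

def crop_to_bool(crop: str) -> bool:
    # "center" keeps the previous behaviour (centre crop); "none" disables it.
    return crop == "center"

print(crop_to_bool("center"))  # True
print(crop_to_bool("none"))    # False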
@@ -989,14 +1008,19 @@ class StyleModelApply:
return {"required": {"conditioning": ("CONDITIONING", ),
"style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply":
cond *= strength
c = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
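With the multiply mode, the style tokens are simply scaled before being concatenated onto the conditioning; a sketch with a dummy tensor (the shape is illustrative only):

import torch

def scale_style_cond(cond, strength, strength_type="multiply"):
    # strength 1.0 is a no-op, 0.0 removes the style tokens' influence entirely.
    if strength_type == "multiply":
        cond = cond * strength
    return cond

cond = torch.ones(1, 16, 4096)  # illustrative token shape, not from the diff
print(scale_style_cond(cond, 0.5).mean().item())  # 0.5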
@@ -1957,6 +1981,12 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
"ImageCrop": "Image Crop",
"ImageBlend": "Image Blend",
"ImageBlur": "Image Blur",
"ImageQuantize": "Image Quantize",
"ImageSharpen": "Image Sharpen",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
@@ -2117,6 +2147,8 @@ def init_builtin_extra_nodes():
"nodes_lora_extract.py",
"nodes_torch_compile.py",
"nodes_mochi.py",
"nodes_slg.py",
"nodes_lt.py",
]
import_failed = []


@@ -29,7 +29,6 @@ import comfy.model_management
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes
@@ -152,7 +151,7 @@ class PromptServer():
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.loop = loop
@@ -676,36 +675,6 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
# NOTE: This was an experiment and WILL BE REMOVED
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
payload = status.to_dict()
payload['download_path'] = filename
await self.send_json("download_progress", payload)
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
if session is None:
logging.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout


@@ -1,337 +0,0 @@
import pytest
import tempfile
import aiohttp
from aiohttp import ClientResponse
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
import folder_paths
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield tmpdirname
class AsyncIteratorMock:
"""
A mock class that simulates an asynchronous iterator.
This is used to mimic the behavior of aiohttp's content iterator.
"""
def __init__(self, seq):
# Convert the input sequence into an iterator
self.iter = iter(seq)
def __aiter__(self):
# This method is called when 'async for' is used
return self
async def __anext__(self):
# This method is called for each iteration in an 'async for' loop
try:
return next(self.iter)
except StopIteration:
# This is the asynchronous equivalent of StopIteration
raise StopAsyncIteration
class ContentMock:
"""
A mock class that simulates the content attribute of an aiohttp ClientResponse.
This class provides the iter_chunked method which returns an async iterator of chunks.
"""
def __init__(self, chunks):
# Store the chunks that will be returned by the iterator
self.chunks = chunks
def iter_chunked(self, chunk_size):
# This method mimics aiohttp's content.iter_chunked()
# For simplicity in testing, we ignore chunk_size and just return our predefined chunks
return AsyncIteratorMock(self.chunks)
@pytest.mark.asyncio
async def test_download_model_success(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.status = 200
mock_response.headers = {'Content-Length': '1000'}
# Create a mock for content that returns an async iterator directly
chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
mock_response.content = ContentMock(chunks)
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
time_values = itertools.count(0, 0.1)
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
temp_dir,
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message == 'Successfully downloaded model.sft'
assert result.status == 'completed'
assert result.already_existed is False
# Check progress callback calls
assert mock_progress_callback.call_count >= 3 # At least start, one progress update, and completion
# Check initial call
mock_progress_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
)
# Check final call
mock_progress_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)
mock_file_path = os.path.join(temp_dir, 'model.sft')
assert os.path.exists(mock_file_path)
with open(mock_file_path, 'rb') as mock_file:
assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)
# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure(temp_dir):
# Mock dependencies
mock_response = AsyncMock(spec=ClientResponse)
mock_response.status = 404 # Simulate a "Not Found" error
mock_get = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()
fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}
# Mock the create_model_path function
with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('folder_paths.folder_names_and_paths', fake_paths):
# Call the function
result = await download_model(
mock_get,
'model.safetensors',
'http://example.com/model.safetensors',
'checkpoints',
temp_dir,
mock_progress_callback
)
# Assert the expected behavior
assert isinstance(result, DownloadModelStatus)
assert result.status == 'error'
assert result.message == 'Failed to download model.safetensors. Status code: 404'
assert result.already_existed is False
# Check that progress_callback was called with the correct arguments
mock_progress_callback.assert_any_call(
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.PENDING,
progress_percentage=0,
message='Starting download of model.safetensors',
already_existed=False
)
)
mock_progress_callback.assert_called_with(
'model.safetensors',
DownloadModelStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
message='Failed to download model.safetensors. Status code: 404',
already_existed=False
)
)
# Verify that the get method was called with the correct URL
mock_get.assert_called_once_with('http://example.com/model.safetensors')
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'../bad_path',
'../bad_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith('Invalid or unrecognized model directory')
assert result.status == 'error'
assert result.already_existed is False
@pytest.mark.asyncio
async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
# Assert the result
assert isinstance(result, DownloadModelStatus)
assert result.message.startswith("Invalid folder path")
assert result.status == 'error'
assert result.already_existed is False
def test_create_model_path(tmp_path, monkeypatch):
model_name = "model.safetensors"
folder_path = os.path.join(tmp_path, "mock_dir")
file_path = create_model_path(model_name, folder_path)
assert file_path == os.path.join(folder_path, "model.safetensors")
assert os.path.exists(os.path.dirname(file_path))
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("../path_traversal.safetensors", folder_path)
with pytest.raises(Exception, match="Invalid model directory"):
create_model_path("/etc/some_root_path", folder_path)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
file_path = tmp_path / "existing_model.sft"
file_path.touch() # Create an empty file
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback)
assert result is not None
assert result.status == "completed"
assert result.message == "existing_model.sft already exists"
assert result.already_existed is True
mock_callback.assert_called_once_with(
"existing_model.sft",
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
file_path = tmp_path / "non_existing_model.sft"
mock_callback = AsyncMock()
result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback)
assert result is None
mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
full_path = os.path.join(temp_dir, 'model.sft')
result = await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=0.1
)
assert result.status == "completed"
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)
@pytest.mark.asyncio
async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)
mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
# Create a mock time function that returns incremental float values
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks
full_path = os.path.join(temp_dir, 'model.sft')
with patch('time.time', mock_time):
await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=1.0
)
assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)
# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
# Verify the first and last calls
first_call = mock_callback.call_args_list[0]
assert first_call[0][1].status == "in_progress"
# Allow for some initial progress, but it should be less than 50%
assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"
last_call = mock_callback.call_args_list[-1]
assert last_call[0][1].status == "completed"
assert last_call[0][1].progress_percentage == 100
@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),
("valid model.safetensors", True), # Test with space
("UPPERCASE_MODEL.SAFETENSORS", True),
("model_with.multiple.dots.pt", False),
("", False), # Empty string
("../../../etc/passwd", False), # Path traversal attempt
("/etc/passwd", False), # Absolute path
("\\windows\\system32\\config\\sam", False), # Windows path
(".hidden_file.pt", False), # Hidden file
("invalid<char>.ckpt", False), # Invalid character
("invalid?.ckpt", False), # Another invalid character
("very" * 100 + ".safetensors", False), # Too long filename
("\nmodel_with_newline.pt", False), # Newline character
("model_with_emoji😊.pt", False), # Emoji in filename
])
def test_validate_filename(filename, expected):
assert validate_filename(filename) == expected


@@ -14,7 +14,7 @@ def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
) if file else tmp_path
return um
@@ -80,9 +80,7 @@ async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
@@ -118,3 +116,116 @@ async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt", data=content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was created with correct content
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == content
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
new_content = b"updated content"
resp = await client.post("/userdata/test.txt", data=new_content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was overwritten
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == new_content
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
assert resp.status == 409
# Verify original content unchanged
with open(tmp_path / "test.txt", "r") as f:
assert f.read() == "initial content"
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
assert resp.status == 200
result = await resp.json()
assert result["path"] == "test.txt"
assert result["size"] == len(content)
assert "modified" in result
async def test_move_userdata(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt")
assert resp.status == 200
assert await resp.text() == '"dest.txt"'
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create source and destination files
with open(tmp_path / "source.txt", "w") as f:
f.write("source content")
with open(tmp_path / "dest.txt", "w") as f:
f.write("destination content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
assert resp.status == 409
# Verify files remain unchanged
with open(tmp_path / "source.txt", "r") as f:
assert f.read() == "source content"
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "destination content"
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
assert resp.status == 200
result = await resp.json()
assert result["path"] == "dest.txt"
assert result["size"] == len("test content")
assert "modified" in result
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"


@@ -8,7 +8,7 @@ from folder_paths import models_dir, user_directory, output_directory
@pytest.fixture
def internal_routes():
return InternalRoutes()
return InternalRoutes(None)
@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
@@ -102,7 +102,7 @@ async def test_file_service_initialization():
# Create a mock instance
mock_file_service_instance = MagicMock(spec=FileService)
MockFileService.return_value = mock_file_service_instance
internal_routes = InternalRoutes()
internal_routes = InternalRoutes(None)
# Check if FileService was initialized with the correct parameters
MockFileService.assert_called_once_with({
@@ -112,4 +112,4 @@ async def test_file_service_initialization():
})
# Verify that the file_service attribute of InternalRoutes is set
assert internal_routes.file_service == mock_file_service_instance
assert internal_routes.file_service == mock_file_service_instance

web/assets/ExtensionPanel-BmKi_NKS.js generated vendored

@@ -1,103 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, bQ as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bR as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a6 as toDisplayString, aw as createTextVNode, j as createCommentVNode, D as script$4 } from "./index-BHayQCxv.js";
import { s as script, a as script$2, b as script$3 } from "./index-CwRXxFdA.js";
import "./index-C_wOqB0f.js";
const _hoisted_1 = { class: "extension-panel" };
const _hoisted_2 = { class: "mt-4" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ExtensionPanel",
setup(__props) {
const extensionStore = useExtensionStore();
const settingStore = useSettingStore();
const editingEnabledExtensions = ref({});
onMounted(() => {
extensionStore.extensions.forEach((ext) => {
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
});
});
const changedExtensions = computed(() => {
return extensionStore.extensions.filter(
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
);
});
const hasChanges = computed(() => {
return changedExtensions.value.length > 0;
});
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
const editingDisabledExtensionNames = Object.entries(
editingEnabledExtensions.value
).filter(([_, enabled]) => !enabled).map(([name]) => name);
settingStore.set("Comfy.Extension.Disabled", [
...extensionStore.inactiveDisabledExtensionNames,
...editingDisabledExtensionNames
]);
}, "updateExtensionStatus");
const applyChanges = /* @__PURE__ */ __name(() => {
window.location.reload();
}, "applyChanges");
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createVNode(unref(script$2), {
value: unref(extensionStore).extensions,
stripedRows: "",
size: "small"
}, {
default: withCtx(() => [
createVNode(unref(script), {
field: "name",
header: _ctx.$t("extensionName"),
sortable: ""
}, null, 8, ["header"]),
createVNode(unref(script), { pt: {
bodyCell: "flex items-center justify-end"
} }, {
body: withCtx((slotProps) => [
createVNode(unref(script$1), {
modelValue: editingEnabledExtensions.value[slotProps.data.name],
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
onChange: updateExtensionStatus
}, null, 8, ["modelValue", "onUpdate:modelValue"])
]),
_: 1
})
]),
_: 1
}, 8, ["value"]),
createBaseVNode("div", _hoisted_2, [
hasChanges.value ? (openBlock(), createBlock(unref(script$3), {
key: 0,
severity: "info"
}, {
default: withCtx(() => [
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
return openBlock(), createElementBlock("li", {
key: ext.name
}, [
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
createTextVNode(" " + toDisplayString(ext.name), 1)
]);
}), 128))
])
]),
_: 1
})) : createCommentVNode("", true),
createVNode(unref(script$4), {
label: _ctx.$t("reloadToApplyChanges"),
icon: "pi pi-refresh",
onClick: applyChanges,
disabled: !hasChanges.value,
text: "",
fluid: "",
severity: "danger"
}, null, 8, ["label", "disabled"])
])
]);
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ExtensionPanel-BmKi_NKS.js.map


@@ -1 +0,0 @@
{"version":3,"file":"ExtensionPanel-BmKi_NKS.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

web/assets/ExtensionPanel-DsD42OtO.js generated vendored Normal file

@@ -0,0 +1,117 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, r as ref, c6 as FilterMatchMode, ca as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, g as openBlock, x as createBlock, y as withCtx, i as createVNode, c7 as SearchBox, z as unref, bT as script, A as createBaseVNode, h as createElementBlock, O as renderList, a6 as toDisplayString, aw as createTextVNode, N as Fragment, D as script$1, j as createCommentVNode, bV as script$3, c8 as _sfc_main$1 } from "./index-CoOvI8ZH.js";
import { s as script$2, a as script$4 } from "./index-DK6Kev7f.js";
import "./index-D4DWQPPQ.js";
const _hoisted_1 = { class: "flex justify-end" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ExtensionPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const extensionStore = useExtensionStore();
const settingStore = useSettingStore();
const editingEnabledExtensions = ref({});
onMounted(() => {
extensionStore.extensions.forEach((ext) => {
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
});
});
const changedExtensions = computed(() => {
return extensionStore.extensions.filter(
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
);
});
const hasChanges = computed(() => {
return changedExtensions.value.length > 0;
});
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
const editingDisabledExtensionNames = Object.entries(
editingEnabledExtensions.value
).filter(([_, enabled]) => !enabled).map(([name]) => name);
settingStore.set("Comfy.Extension.Disabled", [
...extensionStore.inactiveDisabledExtensionNames,
...editingDisabledExtensionNames
]);
}, "updateExtensionStatus");
const applyChanges = /* @__PURE__ */ __name(() => {
window.location.reload();
}, "applyChanges");
return (_ctx, _cache) => {
return openBlock(), createBlock(_sfc_main$1, {
value: "Extension",
class: "extension-panel"
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchExtensions") + "..."
}, null, 8, ["modelValue", "placeholder"]),
hasChanges.value ? (openBlock(), createBlock(unref(script), {
key: 0,
severity: "info",
"pt:text": "w-full"
}, {
default: withCtx(() => [
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
return openBlock(), createElementBlock("li", {
key: ext.name
}, [
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
createTextVNode(" " + toDisplayString(ext.name), 1)
]);
}), 128))
]),
createBaseVNode("div", _hoisted_1, [
createVNode(unref(script$1), {
label: _ctx.$t("reloadToApplyChanges"),
onClick: applyChanges,
outlined: "",
severity: "danger"
}, null, 8, ["label"])
])
]),
_: 1
})) : createCommentVNode("", true)
]),
default: withCtx(() => [
createVNode(unref(script$4), {
value: unref(extensionStore).extensions,
stripedRows: "",
size: "small",
filters: filters.value
}, {
default: withCtx(() => [
createVNode(unref(script$2), {
field: "name",
header: _ctx.$t("extensionName"),
sortable: ""
}, null, 8, ["header"]),
createVNode(unref(script$2), { pt: {
bodyCell: "flex items-center justify-end"
} }, {
body: withCtx((slotProps) => [
createVNode(unref(script$3), {
modelValue: editingEnabledExtensions.value[slotProps.data.name],
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
onChange: updateExtensionStatus
}, null, 8, ["modelValue", "onUpdate:modelValue"])
]),
_: 1
})
]),
_: 1
}, 8, ["value", "filters"])
]),
_: 1
});
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ExtensionPanel-DsD42OtO.js.map

web/assets/ExtensionPanel-DsD42OtO.js.map generated vendored Normal file

@@ -0,0 +1 @@
{"version":3,"file":"ExtensionPanel-DsD42OtO.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Extension\" class=\"extension-panel\">\n <template #header>\n <SearchBox\n v-model=\"filters['global'].value\"\n :placeholder=\"$t('searchExtensions') + '...'\"\n />\n <Message v-if=\"hasChanges\" severity=\"info\" pt:text=\"w-full\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n <div class=\"flex justify-end\">\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n @click=\"applyChanges\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n </template>\n <DataTable\n :value=\"extensionStore.extensions\"\n stripedRows\n size=\"small\"\n :filters=\"filters\"\n >\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport { FilterMatchMode } from '@primevue/core/api'\nimport PanelTemplate from './PanelTemplate.vue'\nimport SearchBox from '@/components/common/SearchBox.vue'\n\nconst filters = ref({\n global: { value: '', matchMode: FilterMatchMode.CONTAINS }\n})\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n 
window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;AA8DA,UAAM,UAAU,IAAI;AAAA,MAClB,QAAQ,EAAE,OAAO,IAAI,WAAW,gBAAgB,SAAS;AAAA,IAAA,CAC1D;AAED,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

File diff suppressed because one or more lines are too long

web/assets/GraphView-BW5soyxY.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -45,7 +45,7 @@
--sidebar-icon-size: 1rem;
}
.side-tool-bar-container[data-v-37fd2fa4] {
.side-tool-bar-container[data-v-e0812a25] {
display: flex;
flex-direction: column;
align-items: center;
@@ -58,28 +58,32 @@
background-color: var(--comfy-menu-bg);
color: var(--fg-color);
}
.side-tool-bar-end[data-v-37fd2fa4] {
.side-tool-bar-end[data-v-e0812a25] {
align-self: flex-end;
margin-top: auto;
}
[data-v-b49f20b1] .p-splitter-gutter {
[data-v-7c3279c1] .p-splitter-gutter {
pointer-events: auto;
}
.side-bar-panel[data-v-b49f20b1] {
[data-v-7c3279c1] .p-splitter-gutter:hover,[data-v-7c3279c1] .p-splitter-gutter[data-p-gutter-resizing='true'] {
transition: background-color 0.2s ease 300ms;
background-color: var(--p-primary-color);
}
.side-bar-panel[data-v-7c3279c1] {
background-color: var(--bg-color);
pointer-events: auto;
}
.bottom-panel[data-v-b49f20b1] {
.bottom-panel[data-v-7c3279c1] {
background-color: var(--bg-color);
pointer-events: auto;
}
.splitter-overlay[data-v-b49f20b1] {
.splitter-overlay[data-v-7c3279c1] {
pointer-events: none;
border-style: none;
background-color: transparent;
}
.splitter-overlay-root[data-v-b49f20b1] {
.splitter-overlay-root[data-v-7c3279c1] {
position: absolute;
top: 0px;
left: 0px;
@@ -102,32 +106,6 @@
margin: -0.125rem 0.125rem;
}
.comfy-vue-node-search-container[data-v-2d409367] {
display: flex;
width: 100%;
min-width: 26rem;
align-items: center;
justify-content: center;
}
.comfy-vue-node-search-container[data-v-2d409367] * {
pointer-events: auto;
}
.comfy-vue-node-preview-container[data-v-2d409367] {
position: absolute;
left: -350px;
top: 50px;
}
.comfy-vue-node-search-box[data-v-2d409367] {
z-index: 10;
flex-grow: 1;
}
._filter-button[data-v-2d409367] {
z-index: 10;
}
._dialog[data-v-2d409367] {
min-width: 26rem;
}
.invisible-dialog-root {
width: 60%;
min-width: 24rem;
@@ -146,7 +124,7 @@
align-items: flex-start !important;
}
.node-tooltip[data-v-79ec8c53] {
.node-tooltip[data-v-c2e0098f] {
background: var(--comfy-input-bg);
border-radius: 5px;
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
@@ -162,22 +140,28 @@
z-index: 99999;
}
.p-buttongroup-vertical[data-v-444d3768] {
.p-buttongroup-vertical[data-v-94481f39] {
display: flex;
flex-direction: column;
border-radius: var(--p-button-border-radius);
overflow: hidden;
border: 1px solid var(--p-panel-border-color);
}
.p-buttongroup-vertical .p-button[data-v-444d3768] {
.p-buttongroup-vertical .p-button[data-v-94481f39] {
margin: 0;
border-radius: 0;
}
[data-v-84e785b8] .p-togglebutton::before {
.comfy-menu-hamburger[data-v-2ddd26e8] {
pointer-events: auto;
position: fixed;
z-index: 9999;
}
[data-v-783f8efe] .p-togglebutton::before {
display: none
}
[data-v-84e785b8] .p-togglebutton {
[data-v-783f8efe] .p-togglebutton {
position: relative;
flex-shrink: 0;
border-radius: 0px;
@@ -185,14 +169,14 @@
padding-left: 0.5rem;
padding-right: 0.5rem
}
[data-v-84e785b8] .p-togglebutton.p-togglebutton-checked {
[data-v-783f8efe] .p-togglebutton.p-togglebutton-checked {
border-bottom-width: 2px;
border-bottom-color: var(--p-button-text-primary-color)
}
[data-v-84e785b8] .p-togglebutton-checked .close-button,[data-v-84e785b8] .p-togglebutton:hover .close-button {
[data-v-783f8efe] .p-togglebutton-checked .close-button,[data-v-783f8efe] .p-togglebutton:hover .close-button {
visibility: visible
}
.status-indicator[data-v-84e785b8] {
.status-indicator[data-v-783f8efe] {
position: absolute;
font-weight: 700;
font-size: 1.5rem;
@@ -200,10 +184,10 @@
left: 50%;
transform: translate(-50%, -50%)
}
[data-v-84e785b8] .p-togglebutton:hover .status-indicator {
[data-v-783f8efe] .p-togglebutton:hover .status-indicator {
display: none
}
[data-v-84e785b8] .p-togglebutton .close-button {
[data-v-783f8efe] .p-togglebutton .close-button {
visibility: hidden
}
@@ -226,35 +210,35 @@
border-bottom-left-radius: 0;
}
.comfyui-queue-button[data-v-2b80bf74] .p-splitbutton-dropdown {
.comfyui-queue-button[data-v-95bc9be0] .p-splitbutton-dropdown {
border-top-right-radius: 0;
border-bottom-right-radius: 0;
}
.actionbar[data-v-2e54db00] {
.actionbar[data-v-542a7001] {
pointer-events: all;
position: fixed;
z-index: 1000;
}
.actionbar.is-docked[data-v-2e54db00] {
.actionbar.is-docked[data-v-542a7001] {
position: static;
border-style: none;
background-color: transparent;
padding: 0px;
}
.actionbar.is-dragging[data-v-2e54db00] {
.actionbar.is-dragging[data-v-542a7001] {
-webkit-user-select: none;
-moz-user-select: none;
user-select: none;
}
[data-v-2e54db00] .p-panel-content {
[data-v-542a7001] .p-panel-content {
padding: 0.25rem;
}
[data-v-2e54db00] .p-panel-header {
[data-v-542a7001] .p-panel-header {
display: none;
}
.comfyui-menu[data-v-ad2c662b] {
.comfyui-menu[data-v-d84a704d] {
width: 100vw;
background: var(--comfy-menu-bg);
color: var(--fg-color);
@@ -266,13 +250,13 @@
grid-column: 1/-1;
max-height: 90vh;
}
.comfyui-menu.dropzone[data-v-ad2c662b] {
.comfyui-menu.dropzone[data-v-d84a704d] {
background: var(--p-highlight-background);
}
.comfyui-menu.dropzone-active[data-v-ad2c662b] {
.comfyui-menu.dropzone-active[data-v-d84a704d] {
background: var(--p-highlight-background-focus);
}
.comfyui-logo[data-v-ad2c662b] {
.comfyui-logo[data-v-d84a704d] {
font-size: 1.2em;
-webkit-user-select: none;
-moz-user-select: none;

1074 web/assets/InstallView-C6UIhIu4.js generated vendored Normal file

File diff suppressed because one or more lines are too long

1 web/assets/InstallView-C6UIhIu4.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

4 web/assets/InstallView-CN3CA9Fk.css generated vendored Normal file

@@ -0,0 +1,4 @@
[data-v-53e62b05] .p-steppanel {
background-color: transparent
}


@@ -1,8 +0,0 @@
[data-v-e5724e4d] .p-datatable-tbody > tr > td {
padding: 1px;
min-height: 2rem;
}
[data-v-e5724e4d] .p-datatable-row-selected .actions,[data-v-e5724e4d] .p-datatable-selectable-row:hover .actions {
visibility: visible;
}

8 web/assets/KeybindingPanel-C-7KE-Kw.css generated vendored Normal file

@@ -0,0 +1,8 @@
[data-v-9d7e362e] .p-datatable-tbody > tr > td {
padding: 0.25rem;
min-height: 2rem
}
[data-v-9d7e362e] .p-datatable-row-selected .actions,[data-v-9d7e362e] .p-datatable-selectable-row:hover .actions {
visibility: visible
}


@@ -1,264 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, bN as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, b9 as useToast, t as resolveDirective, bO as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, ao as script$4, be as withModifiers, aH as script$6, v as withDirectives, P as pushScopeId, Q as popScopeId, bJ as KeyComboImpl, bP as KeybindingImpl, _ as _export_sfc } from "./index-BHayQCxv.js";
import { s as script$1, a as script$3, b as script$5 } from "./index-CwRXxFdA.js";
import "./index-C_wOqB0f.js";
const _hoisted_1$1 = {
key: 0,
class: "px-2"
};
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "KeyComboDisplay",
props: {
keyCombo: {},
isModified: { type: Boolean, default: false }
},
setup(__props) {
const props = __props;
const keySequences = computed(() => props.keyCombo.getKeySequences());
return (_ctx, _cache) => {
return openBlock(), createElementBlock("span", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
return openBlock(), createElementBlock(Fragment, { key: index }, [
createVNode(unref(script), {
severity: _ctx.isModified ? "info" : "secondary"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(sequence), 1)
]),
_: 2
}, 1032, ["severity"]),
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
], 64);
}), 128))
]);
};
}
});
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-e5724e4d"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "keybinding-panel" };
const _hoisted_2 = { class: "actions invisible" };
const _hoisted_3 = { key: 1 };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "KeybindingPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const keybindingStore = useKeybindingStore();
const commandStore = useCommandStore();
const commandsData = computed(() => {
return Object.values(commandStore.commands).map((command) => ({
id: command.id,
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
}));
});
const selectedCommandData = ref(null);
const editDialogVisible = ref(false);
const newBindingKeyCombo = ref(null);
const currentEditingCommand = ref(null);
const keybindingInput = ref(null);
const existingKeybindingOnCombo = computed(() => {
if (!currentEditingCommand.value) {
return null;
}
if (currentEditingCommand.value.keybinding?.combo?.equals(
newBindingKeyCombo.value
)) {
return null;
}
if (!newBindingKeyCombo.value) {
return null;
}
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
});
function editKeybinding(commandData) {
currentEditingCommand.value = commandData;
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
editDialogVisible.value = true;
}
__name(editKeybinding, "editKeybinding");
watchEffect(() => {
if (editDialogVisible.value) {
setTimeout(() => {
keybindingInput.value?.$el?.focus();
}, 300);
}
});
function removeKeybinding(commandData) {
if (commandData.keybinding) {
keybindingStore.unsetKeybinding(commandData.keybinding);
keybindingStore.persistUserKeybindings();
}
}
__name(removeKeybinding, "removeKeybinding");
function captureKeybinding(event) {
const keyCombo = KeyComboImpl.fromEvent(event);
newBindingKeyCombo.value = keyCombo;
}
__name(captureKeybinding, "captureKeybinding");
function cancelEdit() {
editDialogVisible.value = false;
currentEditingCommand.value = null;
newBindingKeyCombo.value = null;
}
__name(cancelEdit, "cancelEdit");
function saveKeybinding() {
if (currentEditingCommand.value && newBindingKeyCombo.value) {
const updated = keybindingStore.updateKeybindingOnCommand(
new KeybindingImpl({
commandId: currentEditingCommand.value.id,
combo: newBindingKeyCombo.value
})
);
if (updated) {
keybindingStore.persistUserKeybindings();
}
}
cancelEdit();
}
__name(saveKeybinding, "saveKeybinding");
const toast = useToast();
async function resetKeybindings() {
keybindingStore.resetKeybindings();
await keybindingStore.persistUserKeybindings();
toast.add({
severity: "info",
summary: "Info",
detail: "Keybindings reset",
life: 3e3
});
}
__name(resetKeybindings, "resetKeybindings");
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createElementBlock("div", _hoisted_1, [
createVNode(unref(script$3), {
value: commandsData.value,
selection: selectedCommandData.value,
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
"global-filter-fields": ["id"],
filters: filters.value,
selectionMode: "single",
stripedRows: "",
pt: {
header: "px-0"
}
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchKeybindings") + "..."
}, null, 8, ["modelValue", "placeholder"])
]),
default: withCtx(() => [
createVNode(unref(script$1), {
field: "actions",
header: ""
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", _hoisted_2, [
createVNode(unref(script$2), {
icon: "pi pi-pencil",
class: "p-button-text",
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
}, null, 8, ["onClick"]),
createVNode(unref(script$2), {
icon: "pi pi-trash",
class: "p-button-text p-button-danger",
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
disabled: !slotProps.data.keybinding
}, null, 8, ["onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$1), {
field: "id",
header: "Command ID",
sortable: ""
}),
createVNode(unref(script$1), {
field: "keybinding",
header: "Keybinding"
}, {
body: withCtx((slotProps) => [
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
key: 0,
keyCombo: slotProps.data.keybinding.combo,
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
]),
_: 1
})
]),
_: 1
}, 8, ["value", "selection", "filters"]),
createVNode(unref(script$6), {
class: "min-w-96",
visible: editDialogVisible.value,
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
modal: "",
header: currentEditingCommand.value?.id,
onHide: cancelEdit
}, {
footer: withCtx(() => [
createVNode(unref(script$2), {
label: "Save",
icon: "pi pi-check",
onClick: saveKeybinding,
disabled: !!existingKeybindingOnCombo.value,
autofocus: ""
}, null, 8, ["disabled"])
]),
default: withCtx(() => [
createBaseVNode("div", null, [
createVNode(unref(script$4), {
class: "mb-2 text-center",
ref_key: "keybindingInput",
ref: keybindingInput,
modelValue: newBindingKeyCombo.value?.toString() ?? "",
placeholder: "Press keys for new binding",
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
autocomplete: "off",
fluid: "",
invalid: !!existingKeybindingOnCombo.value
}, null, 8, ["modelValue", "invalid"]),
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
key: 0,
severity: "error"
}, {
default: withCtx(() => [
createTextVNode(" Keybinding already exists on "),
createVNode(unref(script), {
severity: "secondary",
value: existingKeybindingOnCombo.value.commandId
}, null, 8, ["value"])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
_: 1
}, 8, ["visible", "header"]),
withDirectives(createVNode(unref(script$2), {
class: "mt-4",
label: _ctx.$t("reset"),
icon: "pi pi-trash",
severity: "danger",
fluid: "",
text: "",
onClick: resetKeybindings
}, null, 8, ["label"]), [
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
])
]);
};
}
});
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-e5724e4d"]]);
export {
KeybindingPanel as default
};
//# sourceMappingURL=KeybindingPanel-Dm_3sBT5.js.map

File diff suppressed because one or more lines are too long

279 web/assets/KeybindingPanel-lcJrxHwZ.js generated vendored Normal file

@@ -0,0 +1,279 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, c6 as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, bg as useToast, t as resolveDirective, x as createBlock, c7 as SearchBox, A as createBaseVNode, D as script$2, ao as script$4, bk as withModifiers, bT as script$5, aH as script$6, v as withDirectives, c8 as _sfc_main$2, P as pushScopeId, Q as popScopeId, c1 as KeyComboImpl, c9 as KeybindingImpl, _ as _export_sfc } from "./index-CoOvI8ZH.js";
import { s as script$1, a as script$3 } from "./index-DK6Kev7f.js";
import "./index-D4DWQPPQ.js";
const _hoisted_1$1 = {
key: 0,
class: "px-2"
};
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "KeyComboDisplay",
props: {
keyCombo: {},
isModified: { type: Boolean, default: false }
},
setup(__props) {
const props = __props;
const keySequences = computed(() => props.keyCombo.getKeySequences());
return (_ctx, _cache) => {
return openBlock(), createElementBlock("span", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
return openBlock(), createElementBlock(Fragment, { key: index }, [
createVNode(unref(script), {
severity: _ctx.isModified ? "info" : "secondary"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(sequence), 1)
]),
_: 2
}, 1032, ["severity"]),
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
], 64);
}), 128))
]);
};
}
});
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-9d7e362e"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "actions invisible flex flex-row" };
const _hoisted_2 = ["title"];
const _hoisted_3 = { key: 1 };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "KeybindingPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const keybindingStore = useKeybindingStore();
const commandStore = useCommandStore();
const commandsData = computed(() => {
return Object.values(commandStore.commands).map((command) => ({
id: command.id,
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
}));
});
const selectedCommandData = ref(null);
const editDialogVisible = ref(false);
const newBindingKeyCombo = ref(null);
const currentEditingCommand = ref(null);
const keybindingInput = ref(null);
const existingKeybindingOnCombo = computed(() => {
if (!currentEditingCommand.value) {
return null;
}
if (currentEditingCommand.value.keybinding?.combo?.equals(
newBindingKeyCombo.value
)) {
return null;
}
if (!newBindingKeyCombo.value) {
return null;
}
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
});
function editKeybinding(commandData) {
currentEditingCommand.value = commandData;
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
editDialogVisible.value = true;
}
__name(editKeybinding, "editKeybinding");
watchEffect(() => {
if (editDialogVisible.value) {
setTimeout(() => {
keybindingInput.value?.$el?.focus();
}, 300);
}
});
function removeKeybinding(commandData) {
if (commandData.keybinding) {
keybindingStore.unsetKeybinding(commandData.keybinding);
keybindingStore.persistUserKeybindings();
}
}
__name(removeKeybinding, "removeKeybinding");
function captureKeybinding(event) {
const keyCombo = KeyComboImpl.fromEvent(event);
newBindingKeyCombo.value = keyCombo;
}
__name(captureKeybinding, "captureKeybinding");
function cancelEdit() {
editDialogVisible.value = false;
currentEditingCommand.value = null;
newBindingKeyCombo.value = null;
}
__name(cancelEdit, "cancelEdit");
function saveKeybinding() {
if (currentEditingCommand.value && newBindingKeyCombo.value) {
const updated = keybindingStore.updateKeybindingOnCommand(
new KeybindingImpl({
commandId: currentEditingCommand.value.id,
combo: newBindingKeyCombo.value
})
);
if (updated) {
keybindingStore.persistUserKeybindings();
}
}
cancelEdit();
}
__name(saveKeybinding, "saveKeybinding");
const toast = useToast();
async function resetKeybindings() {
keybindingStore.resetKeybindings();
await keybindingStore.persistUserKeybindings();
toast.add({
severity: "info",
summary: "Info",
detail: "Keybindings reset",
life: 3e3
});
}
__name(resetKeybindings, "resetKeybindings");
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createBlock(_sfc_main$2, {
value: "Keybinding",
class: "keybinding-panel"
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchKeybindings") + "..."
}, null, 8, ["modelValue", "placeholder"])
]),
default: withCtx(() => [
createVNode(unref(script$3), {
value: commandsData.value,
selection: selectedCommandData.value,
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
"global-filter-fields": ["id"],
filters: filters.value,
selectionMode: "single",
stripedRows: "",
pt: {
header: "px-0"
}
}, {
default: withCtx(() => [
createVNode(unref(script$1), {
field: "actions",
header: ""
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", _hoisted_1, [
createVNode(unref(script$2), {
icon: "pi pi-pencil",
class: "p-button-text",
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
}, null, 8, ["onClick"]),
createVNode(unref(script$2), {
icon: "pi pi-trash",
class: "p-button-text p-button-danger",
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
disabled: !slotProps.data.keybinding
}, null, 8, ["onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$1), {
field: "id",
header: "Command ID",
sortable: "",
class: "max-w-64 2xl:max-w-full"
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", {
class: "overflow-hidden text-ellipsis whitespace-nowrap",
title: slotProps.data.id
}, toDisplayString(slotProps.data.id), 9, _hoisted_2)
]),
_: 1
}),
createVNode(unref(script$1), {
field: "keybinding",
header: "Keybinding"
}, {
body: withCtx((slotProps) => [
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
key: 0,
keyCombo: slotProps.data.keybinding.combo,
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
]),
_: 1
})
]),
_: 1
}, 8, ["value", "selection", "filters"]),
createVNode(unref(script$6), {
class: "min-w-96",
visible: editDialogVisible.value,
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
modal: "",
header: currentEditingCommand.value?.id,
onHide: cancelEdit
}, {
footer: withCtx(() => [
createVNode(unref(script$2), {
label: "Save",
icon: "pi pi-check",
onClick: saveKeybinding,
disabled: !!existingKeybindingOnCombo.value,
autofocus: ""
}, null, 8, ["disabled"])
]),
default: withCtx(() => [
createBaseVNode("div", null, [
createVNode(unref(script$4), {
class: "mb-2 text-center",
ref_key: "keybindingInput",
ref: keybindingInput,
modelValue: newBindingKeyCombo.value?.toString() ?? "",
placeholder: "Press keys for new binding",
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
autocomplete: "off",
fluid: "",
invalid: !!existingKeybindingOnCombo.value
}, null, 8, ["modelValue", "invalid"]),
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
key: 0,
severity: "error"
}, {
default: withCtx(() => [
createTextVNode(" Keybinding already exists on "),
createVNode(unref(script), {
severity: "secondary",
value: existingKeybindingOnCombo.value.commandId
}, null, 8, ["value"])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
_: 1
}, 8, ["visible", "header"]),
withDirectives(createVNode(unref(script$2), {
class: "mt-4",
label: _ctx.$t("reset"),
icon: "pi pi-trash",
severity: "danger",
fluid: "",
text: "",
onClick: resetKeybindings
}, null, 8, ["label"]), [
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
])
]),
_: 1
});
};
}
});
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-9d7e362e"]]);
export {
KeybindingPanel as default
};
//# sourceMappingURL=KeybindingPanel-lcJrxHwZ.js.map

1 web/assets/KeybindingPanel-lcJrxHwZ.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

150 web/assets/ServerConfigPanel-x68ubY-c.js generated vendored Normal file

@@ -0,0 +1,150 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { A as createBaseVNode, g as openBlock, h as createElementBlock, aU as markRaw, d as defineComponent, u as useSettingStore, bw as storeToRefs, w as watch, cy as useCopyToClipboard, x as createBlock, y as withCtx, z as unref, bT as script, a6 as toDisplayString, O as renderList, N as Fragment, i as createVNode, D as script$1, j as createCommentVNode, bI as script$2, cz as formatCamelCase, cA as FormItem, c8 as _sfc_main$1, bN as electronAPI } from "./index-CoOvI8ZH.js";
import { u as useServerConfigStore } from "./serverConfigStore-cctR8PGG.js";
const _hoisted_1$1 = {
viewBox: "0 0 24 24",
width: "1.2em",
height: "1.2em"
};
const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("path", {
fill: "none",
stroke: "currentColor",
"stroke-linecap": "round",
"stroke-linejoin": "round",
"stroke-width": "2",
d: "m4 17l6-6l-6-6m8 14h8"
}, null, -1);
const _hoisted_3$1 = [
_hoisted_2$1
];
function render(_ctx, _cache) {
return openBlock(), createElementBlock("svg", _hoisted_1$1, [..._hoisted_3$1]);
}
__name(render, "render");
const __unplugin_components_0 = markRaw({ name: "lucide-terminal", render });
const _hoisted_1 = { class: "flex flex-col gap-2" };
const _hoisted_2 = { class: "flex justify-end gap-2" };
const _hoisted_3 = { class: "flex items-center justify-between" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerConfigPanel",
setup(__props) {
const settingStore = useSettingStore();
const serverConfigStore = useServerConfigStore();
const {
serverConfigsByCategory,
serverConfigValues,
launchArgs,
commandLineArgs,
modifiedConfigs
} = storeToRefs(serverConfigStore);
const revertChanges = /* @__PURE__ */ __name(() => {
serverConfigStore.revertChanges();
}, "revertChanges");
const restartApp = /* @__PURE__ */ __name(() => {
electronAPI().restartApp();
}, "restartApp");
watch(launchArgs, (newVal) => {
settingStore.set("Comfy.Server.LaunchArgs", newVal);
});
watch(serverConfigValues, (newVal) => {
settingStore.set("Comfy.Server.ServerConfigValues", newVal);
});
const { copyToClipboard } = useCopyToClipboard();
const copyCommandLineArgs = /* @__PURE__ */ __name(async () => {
await copyToClipboard(commandLineArgs.value);
}, "copyCommandLineArgs");
return (_ctx, _cache) => {
const _component_i_lucide58terminal = __unplugin_components_0;
return openBlock(), createBlock(_sfc_main$1, {
value: "Server-Config",
class: "server-config-panel"
}, {
header: withCtx(() => [
createBaseVNode("div", _hoisted_1, [
unref(modifiedConfigs).length > 0 ? (openBlock(), createBlock(unref(script), {
key: 0,
severity: "info",
"pt:text": "w-full"
}, {
default: withCtx(() => [
createBaseVNode("p", null, toDisplayString(_ctx.$t("serverConfig.modifiedConfigs")), 1),
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(unref(modifiedConfigs), (config) => {
return openBlock(), createElementBlock("li", {
key: config.id
}, toDisplayString(config.name) + ": " + toDisplayString(config.initialValue) + " → " + toDisplayString(config.value), 1);
}), 128))
]),
createBaseVNode("div", _hoisted_2, [
createVNode(unref(script$1), {
label: _ctx.$t("serverConfig.revertChanges"),
onClick: revertChanges,
outlined: ""
}, null, 8, ["label"]),
createVNode(unref(script$1), {
label: _ctx.$t("serverConfig.restart"),
onClick: restartApp,
outlined: "",
severity: "danger"
}, null, 8, ["label"])
])
]),
_: 1
})) : createCommentVNode("", true),
unref(commandLineArgs) ? (openBlock(), createBlock(unref(script), {
key: 1,
severity: "secondary",
"pt:text": "w-full"
}, {
icon: withCtx(() => [
createVNode(_component_i_lucide58terminal, { class: "text-xl font-bold" })
]),
default: withCtx(() => [
createBaseVNode("div", _hoisted_3, [
createBaseVNode("p", null, toDisplayString(unref(commandLineArgs)), 1),
createVNode(unref(script$1), {
icon: "pi pi-clipboard",
onClick: copyCommandLineArgs,
severity: "secondary",
text: ""
})
])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
default: withCtx(() => [
(openBlock(true), createElementBlock(Fragment, null, renderList(Object.entries(unref(serverConfigsByCategory)), ([label, items], i) => {
return openBlock(), createElementBlock("div", { key: label }, [
i > 0 ? (openBlock(), createBlock(unref(script$2), { key: 0 })) : createCommentVNode("", true),
createBaseVNode("h3", null, toDisplayString(unref(formatCamelCase)(label)), 1),
(openBlock(true), createElementBlock(Fragment, null, renderList(items, (item) => {
return openBlock(), createElementBlock("div", {
key: item.name,
class: "flex items-center mb-4"
}, [
createVNode(FormItem, {
item,
formValue: item.value,
"onUpdate:formValue": /* @__PURE__ */ __name(($event) => item.value = $event, "onUpdate:formValue"),
id: item.id,
labelClass: {
"text-highlight": item.initialValue !== item.value
}
}, null, 8, ["item", "formValue", "onUpdate:formValue", "id", "labelClass"])
]);
}), 128))
]);
}), 128))
]),
_: 1
});
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ServerConfigPanel-x68ubY-c.js.map

1 web/assets/ServerConfigPanel-x68ubY-c.js.map generated vendored Normal file

@@ -0,0 +1 @@
{"version":3,"file":"ServerConfigPanel-x68ubY-c.js","sources":["../../src/components/dialog/content/setting/ServerConfigPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Server-Config\" class=\"server-config-panel\">\n <template #header>\n <div class=\"flex flex-col gap-2\">\n <Message\n v-if=\"modifiedConfigs.length > 0\"\n severity=\"info\"\n pt:text=\"w-full\"\n >\n <p>\n {{ $t('serverConfig.modifiedConfigs') }}\n </p>\n <ul>\n <li v-for=\"config in modifiedConfigs\" :key=\"config.id\">\n {{ config.name }}: {{ config.initialValue }} → {{ config.value }}\n </li>\n </ul>\n <div class=\"flex justify-end gap-2\">\n <Button\n :label=\"$t('serverConfig.revertChanges')\"\n @click=\"revertChanges\"\n outlined\n />\n <Button\n :label=\"$t('serverConfig.restart')\"\n @click=\"restartApp\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n <Message v-if=\"commandLineArgs\" severity=\"secondary\" pt:text=\"w-full\">\n <template #icon>\n <i-lucide:terminal class=\"text-xl font-bold\" />\n </template>\n <div class=\"flex items-center justify-between\">\n <p>{{ commandLineArgs }}</p>\n <Button\n icon=\"pi pi-clipboard\"\n @click=\"copyCommandLineArgs\"\n severity=\"secondary\"\n text\n />\n </div>\n </Message>\n </div>\n </template>\n <div\n v-for=\"([label, items], i) in Object.entries(serverConfigsByCategory)\"\n :key=\"label\"\n >\n <Divider v-if=\"i > 0\" />\n <h3>{{ formatCamelCase(label) }}</h3>\n <div\n v-for=\"item in items\"\n :key=\"item.name\"\n class=\"flex items-center mb-4\"\n >\n <FormItem\n :item=\"item\"\n v-model:formValue=\"item.value\"\n :id=\"item.id\"\n :labelClass=\"{\n 'text-highlight': item.initialValue !== item.value\n }\"\n />\n </div>\n </div>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport Divider from 'primevue/divider'\nimport FormItem from '@/components/common/FormItem.vue'\nimport PanelTemplate from './PanelTemplate.vue'\nimport { formatCamelCase } from '@/utils/formatUtil'\nimport { useServerConfigStore } from '@/stores/serverConfigStore'\nimport { storeToRefs } from 'pinia'\nimport { electronAPI } from '@/utils/envUtil'\nimport { useSettingStore } from '@/stores/settingStore'\nimport { watch } from 'vue'\nimport { useCopyToClipboard } from '@/hooks/clipboardHooks'\n\nconst settingStore = useSettingStore()\nconst serverConfigStore = useServerConfigStore()\nconst {\n serverConfigsByCategory,\n serverConfigValues,\n launchArgs,\n commandLineArgs,\n modifiedConfigs\n} = storeToRefs(serverConfigStore)\n\nconst revertChanges = () => {\n serverConfigStore.revertChanges()\n}\n\nconst restartApp = () => {\n electronAPI().restartApp()\n}\n\nwatch(launchArgs, (newVal) => {\n settingStore.set('Comfy.Server.LaunchArgs', newVal)\n})\n\nwatch(serverConfigValues, (newVal) => {\n settingStore.set('Comfy.Server.ServerConfigValues', newVal)\n})\n\nconst { copyToClipboard } = useCopyToClipboard()\nconst copyCommandLineArgs = async () => {\n await 
copyToClipboard(commandLineArgs.value)\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAqFA,UAAM,eAAe;AACrB,UAAM,oBAAoB;AACpB,UAAA;AAAA,MACJ;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,IAAA,IACE,YAAY,iBAAiB;AAEjC,UAAM,gBAAgB,6BAAM;AAC1B,wBAAkB,cAAc;AAAA,IAAA,GADZ;AAItB,UAAM,aAAa,6BAAM;AACvB,kBAAA,EAAc;IAAW,GADR;AAIb,UAAA,YAAY,CAAC,WAAW;AACf,mBAAA,IAAI,2BAA2B,MAAM;AAAA,IAAA,CACnD;AAEK,UAAA,oBAAoB,CAAC,WAAW;AACvB,mBAAA,IAAI,mCAAmC,MAAM;AAAA,IAAA,CAC3D;AAEK,UAAA,EAAE,oBAAoB;AAC5B,UAAM,sBAAsB,mCAAY;AAChC,YAAA,gBAAgB,gBAAgB,KAAK;AAAA,IAAA,GADjB;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

79 web/assets/ServerStartView-CqRVtr1h.js generated vendored Normal file

@@ -0,0 +1,79 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, aD as useI18n, r as ref, o as onMounted, g as openBlock, h as createElementBlock, A as createBaseVNode, aw as createTextVNode, a6 as toDisplayString, z as unref, j as createCommentVNode, i as createVNode, D as script, bM as BaseTerminal, P as pushScopeId, Q as popScopeId, bN as electronAPI, _ as _export_sfc } from "./index-CoOvI8ZH.js";
import { P as ProgressStatus } from "./index-BppSBmxJ.js";
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-f5429be7"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
const _hoisted_2 = { class: "text-2xl font-bold" };
const _hoisted_3 = { key: 0 };
const _hoisted_4 = {
key: 0,
class: "flex items-center my-4 gap-2"
};
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerStartView",
setup(__props) {
const electron = electronAPI();
const { t } = useI18n();
const status = ref(ProgressStatus.INITIAL_STATE);
const electronVersion = ref("");
let xterm;
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
status.value = newStatus;
xterm?.clear();
}, "updateProgress");
const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
xterm = terminal;
useAutoSize(root, true, true);
electron.onLogMessage((message) => {
terminal.write(message);
});
terminal.options.cursorBlink = false;
terminal.options.disableStdin = true;
terminal.options.cursorInactiveStyle = "block";
}, "terminalCreated");
const reinstall = /* @__PURE__ */ __name(() => electron.reinstall(), "reinstall");
const reportIssue = /* @__PURE__ */ __name(() => {
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
}, "reportIssue");
const openLogs = /* @__PURE__ */ __name(() => electron.openLogsFolder(), "openLogs");
onMounted(async () => {
electron.sendReady();
electron.onProgressUpdate(updateProgress);
electronVersion.value = await electron.getElectronVersion();
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createBaseVNode("h2", _hoisted_2, [
createTextVNode(toDisplayString(unref(t)(`serverStart.process.${status.value}`)) + " ", 1),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_3, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
]),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_4, [
createVNode(unref(script), {
icon: "pi pi-flag",
severity: "secondary",
label: unref(t)("serverStart.reportIssue"),
onClick: reportIssue
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-file",
severity: "secondary",
label: unref(t)("serverStart.openLogs"),
onClick: openLogs
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-refresh",
label: unref(t)("serverStart.reinstall"),
onClick: reinstall
}, null, 8, ["label"])
])) : createCommentVNode("", true),
createVNode(BaseTerminal, { onCreated: terminalCreated })
]);
};
}
});
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-f5429be7"]]);
export {
ServerStartView as default
};
//# sourceMappingURL=ServerStartView-CqRVtr1h.js.map

1 web/assets/ServerStartView-CqRVtr1h.js.map generated vendored Normal file

@@ -0,0 +1 @@
{"version":3,"file":"ServerStartView-CqRVtr1h.js","sources":["../../src/views/ServerStartView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <h2 class=\"text-2xl font-bold\">\n {{ t(`serverStart.process.${status}`) }}\n <span v-if=\"status === ProgressStatus.ERROR\">\n v{{ electronVersion }}\n </span>\n </h2>\n <div\n v-if=\"status === ProgressStatus.ERROR\"\n class=\"flex items-center my-4 gap-2\"\n >\n <Button\n icon=\"pi pi-flag\"\n severity=\"secondary\"\n :label=\"t('serverStart.reportIssue')\"\n @click=\"reportIssue\"\n />\n <Button\n icon=\"pi pi-file\"\n severity=\"secondary\"\n :label=\"t('serverStart.openLogs')\"\n @click=\"openLogs\"\n />\n <Button\n icon=\"pi pi-refresh\"\n :label=\"t('serverStart.reinstall')\"\n @click=\"reinstall\"\n />\n </div>\n <BaseTerminal @created=\"terminalCreated\" />\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { ref, onMounted, Ref } from 'vue'\nimport BaseTerminal from '@/components/bottomPanel/tabs/terminal/BaseTerminal.vue'\nimport { ProgressStatus } from '@comfyorg/comfyui-electron-types'\nimport { electronAPI } from '@/utils/envUtil'\nimport type { useTerminal } from '@/hooks/bottomPanelTabs/useTerminal'\nimport { Terminal } from '@xterm/xterm'\nimport { useI18n } from 'vue-i18n'\n\nconst electron = electronAPI()\nconst { t } = useI18n()\n\nconst status = ref<ProgressStatus>(ProgressStatus.INITIAL_STATE)\nconst electronVersion = ref<string>('')\nlet xterm: Terminal | undefined\n\nconst updateProgress = ({ status: newStatus }: { status: ProgressStatus }) => {\n status.value = newStatus\n xterm?.clear()\n}\n\nconst terminalCreated = (\n { terminal, useAutoSize }: ReturnType<typeof useTerminal>,\n root: Ref<HTMLElement>\n) => {\n xterm = terminal\n\n useAutoSize(root, true, true)\n electron.onLogMessage((message: string) => {\n terminal.write(message)\n })\n\n terminal.options.cursorBlink = false\n terminal.options.disableStdin = true\n terminal.options.cursorInactiveStyle = 'block'\n}\n\nconst reinstall = () => electron.reinstall()\nconst reportIssue = () => {\n window.open('https://forum.comfy.org/c/v1-feedback/', '_blank')\n}\nconst openLogs = () => electron.openLogsFolder()\n\nonMounted(async () => {\n electron.sendReady()\n electron.onProgressUpdate(updateProgress)\n electronVersion.value = await electron.getElectronVersion()\n})\n</script>\n\n<style scoped>\n:deep(.xterm-helper-textarea) {\n /* Hide this as it moves all over when uv is running */\n display: none;\n}\n</style>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AA8CA,UAAM,WAAW;AACX,UAAA,EAAE,MAAM;AAER,UAAA,SAAS,IAAoB,eAAe,aAAa;AACzD,UAAA,kBAAkB,IAAY,EAAE;AAClC,QAAA;AAEJ,UAAM,iBAAiB,wBAAC,EAAE,QAAQ,gBAA4C;AAC5E,aAAO,QAAQ;AACf,aAAO,MAAM;AAAA,IAAA,GAFQ;AAKvB,UAAM,kBAAkB,wBACtB,EAAE,UAAU,YAAA,GACZ,SACG;AACK,cAAA;AAEI,kBAAA,MAAM,MAAM,IAAI;AACnB,eAAA,aAAa,CAAC,YAAoB;AACzC,iBAAS,MAAM,OAAO;AAAA,MAAA,CACvB;AAED,eAAS,QAAQ,cAAc;AAC/B,eAAS,QAAQ,eAAe;AAChC,eAAS,QAAQ,sBAAsB;AAAA,IAAA,GAbjB;AAgBlB,UAAA,YAAY,6BAAM,SAAS,aAAf;AAClB,UAAM,cAAc,6BAAM;AACjB,aAAA,KAAK,0CAA0C,QAAQ;AAAA,IAAA,GAD5C;AAGd,UAAA,WAAW,6BAAM,SAAS,kBAAf;AAEjB,cAAU,YAAY;AACpB,eAAS,UAAU;AACnB,eAAS,iBAAiB,cAAc;AACxB,sBAAA,QAAQ,MAAM,SAAS,mBAAmB;AAAA,IAAA,CAC3D;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

5 web/assets/ServerStartView-Djq8v91B.css generated vendored Normal file

@@ -0,0 +1,5 @@
[data-v-f5429be7] .xterm-helper-textarea {
/* Hide this as it moves all over when uv is running */
display: none;
}

33 web/assets/WelcomeView-C4D1cggT.js generated vendored Normal file

@@ -0,0 +1,33 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, g as openBlock, h as createElementBlock, A as createBaseVNode, a6 as toDisplayString, i as createVNode, z as unref, D as script, P as pushScopeId, Q as popScopeId, _ as _export_sfc } from "./index-CoOvI8ZH.js";
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-12b8b11b"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
const _hoisted_2 = { class: "flex flex-col items-center justify-center gap-8 p-8" };
const _hoisted_3 = { class: "animated-gradient-text text-glow select-none" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "WelcomeView",
setup(__props) {
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createBaseVNode("div", _hoisted_2, [
createBaseVNode("h1", _hoisted_3, toDisplayString(_ctx.$t("welcome.title")), 1),
createVNode(unref(script), {
label: _ctx.$t("welcome.getStarted"),
icon: "pi pi-arrow-right",
iconPos: "right",
size: "large",
rounded: "",
onClick: _cache[0] || (_cache[0] = ($event) => _ctx.$router.push("/install")),
class: "p-4 text-lg fade-in-up"
}, null, 8, ["label"])
])
]);
};
}
});
const WelcomeView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-12b8b11b"]]);
export {
WelcomeView as default
};
//# sourceMappingURL=WelcomeView-C4D1cggT.js.map

1 web/assets/WelcomeView-C4D1cggT.js.map generated vendored Normal file

@@ -0,0 +1 @@
{"version":3,"file":"WelcomeView-C4D1cggT.js","sources":[],"sourcesContent":[],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

36 web/assets/WelcomeView-DQQgHnsr.css generated vendored Normal file

@@ -0,0 +1,36 @@
.animated-gradient-text[data-v-12b8b11b] {
font-weight: 700;
font-size: clamp(2rem, 8vw, 4rem);
background: linear-gradient(to right, #12c2e9, #c471ed, #f64f59, #12c2e9);
background-size: 300% auto;
background-clip: text;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
animation: gradient-12b8b11b 8s linear infinite;
}
.text-glow[data-v-12b8b11b] {
filter: drop-shadow(0 0 8px rgba(255, 255, 255, 0.3));
}
@keyframes gradient-12b8b11b {
0% {
background-position: 0% center;
}
100% {
background-position: 300% center;
}
}
.fade-in-up[data-v-12b8b11b] {
animation: fadeInUp-12b8b11b 1.5s ease-out;
animation-fill-mode: both;
}
@keyframes fadeInUp-12b8b11b {
0% {
opacity: 0;
transform: translateY(20px);
}
100% {
opacity: 1;
transform: translateY(0);
}
}

1 web/assets/index-BHayQCxv.js.map generated vendored

File diff suppressed because one or more lines are too long

4636 web/assets/index-BReiUkk9.js generated vendored

File diff suppressed because it is too large Load Diff

1 web/assets/index-BReiUkk9.js.map generated vendored

File diff suppressed because one or more lines are too long

52310 web/assets/index-Ba7IybyO.js generated vendored Normal file

File diff suppressed because one or more lines are too long

1 web/assets/index-Ba7IybyO.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

106 web/assets/index-BppSBmxJ.js generated vendored Normal file

@@ -0,0 +1,106 @@
const IPC_CHANNELS = {
LOADING_PROGRESS: "loading-progress",
IS_PACKAGED: "is-packaged",
RENDERER_READY: "renderer-ready",
RESTART_APP: "restart-app",
REINSTALL: "reinstall",
LOG_MESSAGE: "log-message",
OPEN_DIALOG: "open-dialog",
DOWNLOAD_PROGRESS: "download-progress",
START_DOWNLOAD: "start-download",
PAUSE_DOWNLOAD: "pause-download",
RESUME_DOWNLOAD: "resume-download",
CANCEL_DOWNLOAD: "cancel-download",
DELETE_MODEL: "delete-model",
GET_ALL_DOWNLOADS: "get-all-downloads",
GET_ELECTRON_VERSION: "get-electron-version",
SEND_ERROR_TO_SENTRY: "send-error-to-sentry",
GET_BASE_PATH: "get-base-path",
GET_MODEL_CONFIG_PATH: "get-model-config-path",
OPEN_PATH: "open-path",
OPEN_LOGS_PATH: "open-logs-path",
OPEN_DEV_TOOLS: "open-dev-tools",
TERMINAL_WRITE: "execute-terminal-command",
TERMINAL_RESIZE: "resize-terminal",
TERMINAL_RESTORE: "restore-terminal",
TERMINAL_ON_OUTPUT: "terminal-output",
IS_FIRST_TIME_SETUP: "is-first-time-setup",
GET_SYSTEM_PATHS: "get-system-paths",
VALIDATE_INSTALL_PATH: "validate-install-path",
VALIDATE_COMFYUI_SOURCE: "validate-comfyui-source",
SHOW_DIRECTORY_PICKER: "show-directory-picker",
INSTALL_COMFYUI: "install-comfyui"
};
var ProgressStatus = /* @__PURE__ */ ((ProgressStatus2) => {
ProgressStatus2["INITIAL_STATE"] = "initial-state";
ProgressStatus2["PYTHON_SETUP"] = "python-setup";
ProgressStatus2["STARTING_SERVER"] = "starting-server";
ProgressStatus2["READY"] = "ready";
ProgressStatus2["ERROR"] = "error";
return ProgressStatus2;
})(ProgressStatus || {});
const ProgressMessages = {
[
"initial-state"
/* INITIAL_STATE */
]: "Loading...",
[
"python-setup"
/* PYTHON_SETUP */
]: "Setting up Python Environment...",
[
"starting-server"
/* STARTING_SERVER */
]: "Starting ComfyUI server...",
[
"ready"
/* READY */
]: "Finishing...",
[
"error"
/* ERROR */
]: "Was not able to start ComfyUI. Please check the logs for more details. You can open it from the Help menu. Please report issues to: https://forum.comfy.org"
};
const ELECTRON_BRIDGE_API = "electronAPI";
const SENTRY_URL_ENDPOINT = "https://942cadba58d247c9cab96f45221aa813@o4507954455314432.ingest.us.sentry.io/4508007940685824";
const MigrationItems = [
{
id: "user_files",
label: "User Files",
description: "Settings and user-created workflows"
},
{
id: "models",
label: "Models",
description: "Reference model files from existing ComfyUI installations. (No copy)"
}
// TODO: Decide whether we want to auto-migrate custom nodes, and install their dependencies.
// huchenlei: This is a very essential thing for migration experience.
// {
// id: 'custom_nodes',
// label: 'Custom Nodes',
// description: 'Reference custom node files from existing ComfyUI installations. (No copy)',
// },
];
const DEFAULT_SERVER_ARGS = {
/** The host to use for the ComfyUI server. */
host: "127.0.0.1",
/** The port to use for the ComfyUI server. */
port: 8e3,
// Extra arguments to pass to the ComfyUI server.
extraServerArgs: {}
};
var DownloadStatus = /* @__PURE__ */ ((DownloadStatus2) => {
DownloadStatus2["PENDING"] = "pending";
DownloadStatus2["IN_PROGRESS"] = "in_progress";
DownloadStatus2["COMPLETED"] = "completed";
DownloadStatus2["PAUSED"] = "paused";
DownloadStatus2["ERROR"] = "error";
DownloadStatus2["CANCELLED"] = "cancelled";
return DownloadStatus2;
})(DownloadStatus || {});
export {
MigrationItems as M,
ProgressStatus as P
};
//# sourceMappingURL=index-BppSBmxJ.js.map

1 web/assets/index-BppSBmxJ.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

102 web/assets/index-C_wOqB0f.js generated vendored

@@ -1,102 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { bS as script$4, A as createBaseVNode, g as openBlock, h as createElementBlock, m as mergeProps } from "./index-BHayQCxv.js";
var script$3 = {
name: "BarsIcon",
"extends": script$4
};
var _hoisted_1$3 = /* @__PURE__ */ createBaseVNode("path", {
"fill-rule": "evenodd",
"clip-rule": "evenodd",
d: "M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z",
fill: "currentColor"
}, null, -1);
var _hoisted_2$3 = [_hoisted_1$3];
function render$3(_ctx, _cache, $props, $setup, $data, $options) {
return openBlock(), createElementBlock("svg", mergeProps({
width: "14",
height: "14",
viewBox: "0 0 14 14",
fill: "none",
xmlns: "http://www.w3.org/2000/svg"
}, _ctx.pti()), _hoisted_2$3, 16);
}
__name(render$3, "render$3");
script$3.render = render$3;
var script$2 = {
name: "PlusIcon",
"extends": script$4
};
var _hoisted_1$2 = /* @__PURE__ */ createBaseVNode("path", {
d: "M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z",
fill: "currentColor"
}, null, -1);
var _hoisted_2$2 = [_hoisted_1$2];
function render$2(_ctx, _cache, $props, $setup, $data, $options) {
return openBlock(), createElementBlock("svg", mergeProps({
width: "14",
height: "14",
viewBox: "0 0 14 14",
fill: "none",
xmlns: "http://www.w3.org/2000/svg"
}, _ctx.pti()), _hoisted_2$2, 16);
}
__name(render$2, "render$2");
script$2.render = render$2;
var script$1 = {
name: "ExclamationTriangleIcon",
"extends": script$4
};
var _hoisted_1$1 = /* @__PURE__ */ createBaseVNode("path", {
d: "M13.4018 13.1893H0.598161C0.49329 13.189 0.390283 13.1615 0.299143 13.1097C0.208003 13.0578 0.131826 12.9832 0.0780112 12.8932C0.0268539 12.8015 0 12.6982 0 12.5931C0 12.4881 0.0268539 12.3848 0.0780112 12.293L6.47985 1.08982C6.53679 1.00399 6.61408 0.933574 6.70484 0.884867C6.7956 0.836159 6.897 0.810669 7 0.810669C7.103 0.810669 7.2044 0.836159 7.29516 0.884867C7.38592 0.933574 7.46321 1.00399 7.52015 1.08982L13.922 12.293C13.9731 12.3848 14 12.4881 14 12.5931C14 12.6982 13.9731 12.8015 13.922 12.8932C13.8682 12.9832 13.792 13.0578 13.7009 13.1097C13.6097 13.1615 13.5067 13.189 13.4018 13.1893ZM1.63046 11.989H12.3695L7 2.59425L1.63046 11.989Z",
fill: "currentColor"
}, null, -1);
var _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("path", {
d: "M6.99996 8.78801C6.84143 8.78594 6.68997 8.72204 6.57787 8.60993C6.46576 8.49782 6.40186 8.34637 6.39979 8.18784V5.38703C6.39979 5.22786 6.46302 5.0752 6.57557 4.96265C6.68813 4.85009 6.84078 4.78686 6.99996 4.78686C7.15914 4.78686 7.31179 4.85009 7.42435 4.96265C7.5369 5.0752 7.60013 5.22786 7.60013 5.38703V8.18784C7.59806 8.34637 7.53416 8.49782 7.42205 8.60993C7.30995 8.72204 7.15849 8.78594 6.99996 8.78801Z",
fill: "currentColor"
}, null, -1);
var _hoisted_3 = /* @__PURE__ */ createBaseVNode("path", {
d: "M6.99996 11.1887C6.84143 11.1866 6.68997 11.1227 6.57787 11.0106C6.46576 10.8985 6.40186 10.7471 6.39979 10.5885V10.1884C6.39979 10.0292 6.46302 9.87658 6.57557 9.76403C6.68813 9.65147 6.84078 9.58824 6.99996 9.58824C7.15914 9.58824 7.31179 9.65147 7.42435 9.76403C7.5369 9.87658 7.60013 10.0292 7.60013 10.1884V10.5885C7.59806 10.7471 7.53416 10.8985 7.42205 11.0106C7.30995 11.1227 7.15849 11.1866 6.99996 11.1887Z",
fill: "currentColor"
}, null, -1);
var _hoisted_4 = [_hoisted_1$1, _hoisted_2$1, _hoisted_3];
function render$1(_ctx, _cache, $props, $setup, $data, $options) {
return openBlock(), createElementBlock("svg", mergeProps({
width: "14",
height: "14",
viewBox: "0 0 14 14",
fill: "none",
xmlns: "http://www.w3.org/2000/svg"
}, _ctx.pti()), _hoisted_4, 16);
}
__name(render$1, "render$1");
script$1.render = render$1;
var script = {
name: "InfoCircleIcon",
"extends": script$4
};
var _hoisted_1 = /* @__PURE__ */ createBaseVNode("path", {
"fill-rule": "evenodd",
"clip-rule": "evenodd",
d: "M3.11101 12.8203C4.26215 13.5895 5.61553 14 7 14C8.85652 14 10.637 13.2625 11.9497 11.9497C13.2625 10.637 14 8.85652 14 7C14 5.61553 13.5895 4.26215 12.8203 3.11101C12.0511 1.95987 10.9579 1.06266 9.67879 0.532846C8.3997 0.00303296 6.99224 -0.13559 5.63437 0.134506C4.2765 0.404603 3.02922 1.07129 2.05026 2.05026C1.07129 3.02922 0.404603 4.2765 0.134506 5.63437C-0.13559 6.99224 0.00303296 8.3997 0.532846 9.67879C1.06266 10.9579 1.95987 12.0511 3.11101 12.8203ZM3.75918 2.14976C4.71846 1.50879 5.84628 1.16667 7 1.16667C8.5471 1.16667 10.0308 1.78125 11.1248 2.87521C12.2188 3.96918 12.8333 5.45291 12.8333 7C12.8333 8.15373 12.4912 9.28154 11.8502 10.2408C11.2093 11.2001 10.2982 11.9478 9.23232 12.3893C8.16642 12.8308 6.99353 12.9463 5.86198 12.7212C4.73042 12.4962 3.69102 11.9406 2.87521 11.1248C2.05941 10.309 1.50384 9.26958 1.27876 8.13803C1.05367 7.00647 1.16919 5.83358 1.61071 4.76768C2.05222 3.70178 2.79989 2.79074 3.75918 2.14976ZM7.00002 4.8611C6.84594 4.85908 6.69873 4.79698 6.58977 4.68801C6.48081 4.57905 6.4187 4.43185 6.41669 4.27776V3.88888C6.41669 3.73417 6.47815 3.58579 6.58754 3.4764C6.69694 3.367 6.84531 3.30554 7.00002 3.30554C7.15473 3.30554 7.3031 3.367 7.4125 3.4764C7.52189 3.58579 7.58335 3.73417 7.58335 3.88888V4.27776C7.58134 4.43185 7.51923 4.57905 7.41027 4.68801C7.30131 4.79698 7.1541 4.85908 7.00002 4.8611ZM7.00002 10.6945C6.84594 10.6925 6.69873 10.6304 6.58977 10.5214C6.48081 10.4124 6.4187 10.2652 6.41669 10.1111V6.22225C6.41669 6.06754 6.47815 5.91917 6.58754 5.80977C6.69694 5.70037 6.84531 5.63892 7.00002 5.63892C7.15473 5.63892 7.3031 5.70037 7.4125 5.80977C7.52189 5.91917 7.58335 6.06754 7.58335 6.22225V10.1111C7.58134 10.2652 7.51923 10.4124 7.41027 10.5214C7.30131 10.6304 7.1541 10.6925 7.00002 10.6945Z",
fill: "currentColor"
}, null, -1);
var _hoisted_2 = [_hoisted_1];
function render(_ctx, _cache, $props, $setup, $data, $options) {
return openBlock(), createElementBlock("svg", mergeProps({
width: "14",
height: "14",
viewBox: "0 0 14 14",
fill: "none",
xmlns: "http://www.w3.org/2000/svg"
}, _ctx.pti()), _hoisted_2, 16);
}
__name(render, "render");
script.render = render;
export {
script$1 as a,
script$3 as b,
script$2 as c,
script as s
};
//# sourceMappingURL=index-C_wOqB0f.js.map

1 web/assets/index-C_wOqB0f.js.map generated vendored

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1 web/assets/index-CoOvI8ZH.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

1 web/assets/index-CwRXxFdA.js.map generated vendored

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff.