mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-08 21:49:54 +00:00
Compare commits
86 Commits
mark-dtype
...
range-type
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8822627a60 | ||
|
|
b615af1c65 | ||
|
|
40862c0776 | ||
|
|
50076f3439 | ||
|
|
61c2387436 | ||
|
|
7083484a48 | ||
|
|
4b1444fc7a | ||
|
|
8cbbea8f6a | ||
|
|
13917b3880 | ||
|
|
f21f6b2212 | ||
|
|
eb0686bbb6 | ||
|
|
5de94e70ec | ||
|
|
76b75f3ad7 | ||
|
|
0c63b4f6e3 | ||
|
|
7d437687c2 | ||
|
|
e2ddf28d78 | ||
|
|
076639fed9 | ||
|
|
55e6478526 | ||
|
|
537c10d231 | ||
|
|
8d723d2caa | ||
|
|
d113d1cc32 | ||
|
|
a500f1edac | ||
|
|
3f77450ef1 | ||
|
|
fc1fdf3389 | ||
|
|
b353a7c863 | ||
|
|
3696c5bad6 | ||
|
|
3a56201da5 | ||
|
|
6a2cdb817d | ||
|
|
85b7495135 | ||
|
|
225c52f6a4 | ||
|
|
b1fdbeb9a7 | ||
|
|
1dc64f3526 | ||
|
|
359559c913 | ||
|
|
8165485a17 | ||
|
|
b0fd65e884 | ||
|
|
2a1f402601 | ||
|
|
3eba2dcf2d | ||
|
|
404d7b9978 | ||
|
|
6580a6bc01 | ||
|
|
3b15651bc6 | ||
|
|
a55835f10c | ||
|
|
b53b10ea61 | ||
|
|
7d5534d8e5 | ||
|
|
5ebb0c2e0b | ||
|
|
a0a64c679f | ||
|
|
8e73678dae | ||
|
|
c2862b24af | ||
|
|
f9ec85f739 | ||
|
|
2d5fd3f5dd | ||
|
|
2d4970ff67 | ||
|
|
e87858e974 | ||
|
|
da6edb5a4e | ||
|
|
6265a239f3 | ||
|
|
d49420b3c7 | ||
|
|
ebf6b52e32 | ||
|
|
25b6d1d629 | ||
|
|
11c15d8832 | ||
|
|
b5d32e6ad2 | ||
|
|
a11f68dd3b | ||
|
|
dc719cde9c | ||
|
|
87cda1fc25 | ||
|
|
45d5c83a30 | ||
|
|
c646d211be | ||
|
|
589228e671 | ||
|
|
e4455fd43a | ||
|
|
f49856af57 | ||
|
|
82b868a45a | ||
|
|
8458ae2686 | ||
|
|
fd0261d2bc | ||
|
|
ab14541ef7 | ||
|
|
6589562ae3 | ||
|
|
fabed694a2 | ||
|
|
f6b869d7d3 | ||
|
|
56ff88f951 | ||
|
|
9fff091f35 | ||
|
|
dcd659590f | ||
|
|
b67ed2a45f | ||
|
|
06957022d4 | ||
|
|
b941913f1d | ||
|
|
cad24ce262 | ||
|
|
68d542cc06 | ||
|
|
735a0465e5 | ||
|
|
8b9d039f26 | ||
|
|
035414ede4 | ||
|
|
1a157e1f97 | ||
|
|
ed7c2c6579 |
36
.github/workflows/release-stable-all.yml
vendored
36
.github/workflows/release-stable-all.yml
vendored
@@ -20,29 +20,12 @@ jobs:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "11"
|
||||
python_patch: "12"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu128:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu128"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu128"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu128"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
@@ -76,3 +59,20 @@ jobs:
|
||||
rel_extra_name: ""
|
||||
test_release: false
|
||||
secrets: inherit
|
||||
|
||||
release_xpu:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release Intel XPU"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "xpu"
|
||||
python_minor: "13"
|
||||
python_patch: "12"
|
||||
rel_name: "intel"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
@@ -61,6 +61,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- NOTE: There are many more models supported than the list below, if you want to see what is supported see our templates list inside ComfyUI.
|
||||
- Image Models
|
||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
@@ -136,7 +137,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
||||
- Builds a new release using the latest stable core version
|
||||
|
||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||
- Weekly frontend updates are merged into the core repository
|
||||
- Every 2+ weeks frontend updates are merged into the core repository
|
||||
- Features are frozen for the upcoming core release
|
||||
- Development continues for the next release cycle
|
||||
|
||||
@@ -232,7 +233,7 @@ Put your VAE in: models/vae
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||
|
||||
@@ -275,7 +276,7 @@ Nvidia users should install stable pytorch using this command:
|
||||
|
||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu130```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu132```
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
create_stub_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
count_active_siblings,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
@@ -80,6 +82,8 @@ __all__ = [
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"count_active_siblings",
|
||||
"create_stub_asset",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
|
||||
@@ -78,6 +78,18 @@ def upsert_asset(
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
def create_stub_asset(
|
||||
session: Session,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> Asset:
|
||||
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
|
||||
@@ -114,6 +114,23 @@ def get_reference_by_file_path(
|
||||
)
|
||||
|
||||
|
||||
def count_active_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
exclude_reference_id: str,
|
||||
) -> int:
|
||||
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||
return (
|
||||
session.query(AssetReference)
|
||||
.filter(
|
||||
AssetReference.asset_id == asset_id,
|
||||
AssetReference.id != exclude_reference_id,
|
||||
AssetReference.deleted_at.is_(None),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
def reference_exists_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
@@ -338,6 +339,7 @@ def build_asset_specs(
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
"job_id": None,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
@@ -426,6 +428,7 @@ def enrich_asset(
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
@@ -489,6 +492,18 @@ def enrich_asset(
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||
# started (e.g. ingest_existing_file updated it), our results are
|
||||
# stale — discard them to avoid overwriting fresh registration data.
|
||||
ref = get_reference_by_id(session, reference_id)
|
||||
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||
session.rollback()
|
||||
logging.info(
|
||||
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||
reference_id,
|
||||
)
|
||||
return ENRICHMENT_STUB
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
@@ -77,7 +77,9 @@ class _AssetSeeder:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
# RLock is required because _run_scan() drains pending work while
|
||||
# holding _lock and re-enters start() which also acquires _lock.
|
||||
self._lock = threading.RLock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
@@ -92,6 +94,7 @@ class _AssetSeeder:
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
self._pending_enrich: dict | None = None
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
@@ -196,6 +199,42 @@ class _AssetSeeder:
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def enqueue_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||
|
||||
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||
request is stored and will run automatically when the current scan
|
||||
finishes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if started immediately, False if queued for later
|
||||
"""
|
||||
with self._lock:
|
||||
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||
return True
|
||||
if self._pending_enrich is not None:
|
||||
existing_roots = set(self._pending_enrich["roots"])
|
||||
existing_roots.update(roots)
|
||||
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||
self._pending_enrich["compute_hashes"] = (
|
||||
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||
)
|
||||
else:
|
||||
self._pending_enrich = {
|
||||
"roots": roots,
|
||||
"compute_hashes": compute_hashes,
|
||||
}
|
||||
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||
return False
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
@@ -381,9 +420,13 @@ class _AssetSeeder:
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
|
||||
def _reset_to_idle(self) -> None:
|
||||
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
@@ -594,9 +637,18 @@ class _AssetSeeder:
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
pending = self._pending_enrich
|
||||
if pending is not None:
|
||||
self._pending_enrich = None
|
||||
if not self.start_enrich(
|
||||
roots=pending["roots"],
|
||||
compute_hashes=pending["compute_hashes"],
|
||||
):
|
||||
logging.warning(
|
||||
"Pending enrich scan could not start (roots=%s)",
|
||||
pending["roots"],
|
||||
)
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
ingest_existing_file,
|
||||
register_output_files,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
@@ -72,6 +74,8 @@ __all__ = [
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"ingest_existing_file",
|
||||
"register_output_files",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
|
||||
@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
job_id: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
job_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"job_id": spec.get("job_id"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
|
||||
@@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
count_active_siblings,
|
||||
create_stub_asset,
|
||||
ensure_tags_exist,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_file_path,
|
||||
@@ -23,7 +26,8 @@ from app.assets.database.queries import (
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
@@ -130,6 +134,102 @@ def _ingest_file_from_path(
|
||||
)
|
||||
|
||||
|
||||
def register_output_files(
|
||||
file_paths: Sequence[str],
|
||||
user_metadata: UserMetadata = None,
|
||||
job_id: str | None = None,
|
||||
) -> int:
|
||||
"""Register a batch of output file paths as assets.
|
||||
|
||||
Returns the number of files successfully registered.
|
||||
"""
|
||||
registered = 0
|
||||
for abs_path in file_paths:
|
||||
if not os.path.isfile(abs_path):
|
||||
continue
|
||||
try:
|
||||
if ingest_existing_file(
|
||||
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||
):
|
||||
registered += 1
|
||||
except Exception:
|
||||
logging.exception("Failed to register output: %s", abs_path)
|
||||
return registered
|
||||
|
||||
|
||||
def ingest_existing_file(
|
||||
abs_path: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
extra_tags: Sequence[str] = (),
|
||||
owner_id: str = "",
|
||||
job_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Register an existing on-disk file as an asset stub.
|
||||
|
||||
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||
|
||||
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||
UX visibility.
|
||||
|
||||
Returns True if a row was inserted or updated, False otherwise.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||
|
||||
with create_session() as session:
|
||||
existing_ref = get_reference_by_file_path(session, locator)
|
||||
if existing_ref is not None:
|
||||
now = get_utc_now()
|
||||
existing_ref.mtime_ns = mtime_ns
|
||||
existing_ref.job_id = job_id
|
||||
existing_ref.is_missing = False
|
||||
existing_ref.deleted_at = None
|
||||
existing_ref.updated_at = now
|
||||
existing_ref.enrichment_level = 0
|
||||
|
||||
asset = existing_ref.asset
|
||||
if asset:
|
||||
# If other refs share this asset, detach to a new stub
|
||||
# instead of mutating the shared row.
|
||||
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||
if siblings > 0:
|
||||
new_asset = create_stub_asset(
|
||||
session,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type or asset.mime_type,
|
||||
)
|
||||
existing_ref.asset_id = new_asset.id
|
||||
else:
|
||||
asset.hash = None
|
||||
asset.size_bytes = size_bytes
|
||||
if mime_type:
|
||||
asset.mime_type = mime_type
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
spec = {
|
||||
"abs_path": abs_path,
|
||||
"size_bytes": size_bytes,
|
||||
"mtime_ns": mtime_ns,
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": os.path.basename(abs_path),
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": mime_type,
|
||||
"job_id": job_id,
|
||||
}
|
||||
if tags:
|
||||
ensure_tags_exist(session, tags)
|
||||
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||
session.commit()
|
||||
return result.won_paths > 0
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
|
||||
@@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "models"], str]:
|
||||
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
- 'input': under folder_paths.get_input_directory()
|
||||
- 'output': under folder_paths.get_output_directory()
|
||||
- 'temp': under folder_paths.get_temp_directory()
|
||||
- 'models': under any base path from get_comfy_models_folders()
|
||||
|
||||
Returns:
|
||||
@@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
|
||||
if _check_is_within(fp_abs, output_base):
|
||||
return "output", _compute_relative(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
# 3) temp
|
||||
temp_base = os.path.abspath(folder_paths.get_temp_directory())
|
||||
if _check_is_within(fp_abs, temp_base):
|
||||
return "temp", _compute_relative(fp_abs, temp_base)
|
||||
|
||||
# 4) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
@@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
||||
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
@@ -0,0 +1,90 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0;
|
||||
uniform float u_float1;
|
||||
uniform float u_float2;
|
||||
uniform float u_float3;
|
||||
uniform float u_float4;
|
||||
uniform float u_float5;
|
||||
uniform float u_float6;
|
||||
uniform float u_float7;
|
||||
uniform float u_float8;
|
||||
uniform bool u_bool0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(c.r, max(c.g, c.b));
|
||||
float minC = min(c.r, min(c.g, c.b));
|
||||
float l = (maxC + minC) * 0.5;
|
||||
if (maxC == minC) return vec3(0.0, 0.0, l);
|
||||
float d = maxC - minC;
|
||||
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
|
||||
float h;
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / d + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / d + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
if (t < 0.0) t += 1.0;
|
||||
if (t > 1.0) t -= 1.0;
|
||||
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 1.0 / 2.0) return q;
|
||||
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
float h = hsl.x, s = hsl.y, l = hsl.z;
|
||||
if (s == 0.0) return vec3(l);
|
||||
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
|
||||
float p = 2.0 * l - q;
|
||||
return vec3(
|
||||
hue2rgb(p, q, h + 1.0 / 3.0),
|
||||
hue2rgb(p, q, h),
|
||||
hue2rgb(p, q, h - 1.0 / 3.0)
|
||||
);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
|
||||
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
|
||||
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
|
||||
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float lightness = (maxC + minC) * 0.5;
|
||||
|
||||
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
|
||||
const float a = 0.25;
|
||||
const float b = 0.333;
|
||||
const float scale = 0.7;
|
||||
|
||||
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
|
||||
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
|
||||
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
|
||||
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
|
||||
|
||||
color += sw * shadows + mw * midtones + hw * highlights;
|
||||
|
||||
if (u_bool0) {
|
||||
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
|
||||
hsl.z = lightness;
|
||||
color = hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
49
blueprints/.glsl/Color_Curves_8.frag
Normal file
49
blueprints/.glsl/Color_Curves_8.frag
Normal file
@@ -0,0 +1,49 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
|
||||
uniform sampler2D u_curve1; // Red channel curve
|
||||
uniform sampler2D u_curve2; // Green channel curve
|
||||
uniform sampler2D u_curve3; // Blue channel curve
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// GIMP-compatible curve lookup with manual linear interpolation.
|
||||
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
|
||||
// index = value * (n_samples - 1)
|
||||
// f = fract(index)
|
||||
// result = (1-f) * samples[floor] + f * samples[ceil]
|
||||
//
|
||||
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
|
||||
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
|
||||
float applyCurve(sampler2D curve, float value) {
|
||||
value = clamp(value, 0.0, 1.0);
|
||||
|
||||
float pos = value * 255.0;
|
||||
int lo = int(floor(pos));
|
||||
int hi = min(lo + 1, 255);
|
||||
float f = pos - float(lo);
|
||||
|
||||
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
|
||||
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
|
||||
|
||||
return a + f * (b - a);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// GIMP order: per-channel curves first, then RGB master curve.
|
||||
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||
// dest = colors_curve( channel_curve( src ) )
|
||||
float tmp_r = applyCurve(u_curve1, color.r);
|
||||
float tmp_g = applyCurve(u_curve2, color.g);
|
||||
float tmp_b = applyCurve(u_curve3, color.b);
|
||||
color.r = applyCurve(u_curve0, tmp_r);
|
||||
color.g = applyCurve(u_curve0, tmp_g);
|
||||
color.b = applyCurve(u_curve0, tmp_b);
|
||||
|
||||
fragColor0 = vec4(color.rgb, color.a);
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1136
blueprints/Color Balance.json
Normal file
1136
blueprints/Color Balance.json
Normal file
File diff suppressed because it is too large
Load Diff
615
blueprints/Color Curves.json
Normal file
615
blueprints/Color Curves.json
Normal file
@@ -0,0 +1,615 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 10,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 10,
|
||||
"type": "d5c462c8-1372-4af8-84f2-547c83470d04",
|
||||
"pos": [
|
||||
3610,
|
||||
-2630
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
420
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image",
|
||||
"localized_name": "images.image0",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"label": "IMAGE",
|
||||
"localized_name": "IMAGE0",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"4",
|
||||
"curve"
|
||||
],
|
||||
[
|
||||
"5",
|
||||
"curve"
|
||||
],
|
||||
[
|
||||
"6",
|
||||
"curve"
|
||||
],
|
||||
[
|
||||
"7",
|
||||
"curve"
|
||||
]
|
||||
]
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Color Curves"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "d5c462c8-1372-4af8-84f2-547c83470d04",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 9,
|
||||
"lastLinkId": 38,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Color Curves",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
2660,
|
||||
-4500,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
4270,
|
||||
-4500,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "abc345b7-f55e-4f32-a11d-3aa4c2b0936b",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
29,
|
||||
34
|
||||
],
|
||||
"localized_name": "images.image0",
|
||||
"label": "image",
|
||||
"pos": [
|
||||
2760,
|
||||
-4480
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "eb0ec079-46da-4408-8263-9ef85569d33d",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
28
|
||||
],
|
||||
"localized_name": "IMAGE0",
|
||||
"label": "IMAGE",
|
||||
"pos": [
|
||||
4290,
|
||||
-4480
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 4,
|
||||
"type": "CurveEditor",
|
||||
"pos": [
|
||||
3060,
|
||||
-4500
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
200
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "curve",
|
||||
"localized_name": "curve",
|
||||
"name": "curve",
|
||||
"type": "CURVE",
|
||||
"widget": {
|
||||
"name": "curve"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "histogram",
|
||||
"localized_name": "histogram",
|
||||
"name": "histogram",
|
||||
"type": "HISTOGRAM",
|
||||
"shape": 7,
|
||||
"link": 35
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "CURVE",
|
||||
"name": "CURVE",
|
||||
"type": "CURVE",
|
||||
"links": [
|
||||
30
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "RGB Master",
|
||||
"properties": {
|
||||
"Node name for S&R": "CurveEditor"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "CurveEditor",
|
||||
"pos": [
|
||||
3060,
|
||||
-4250
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
200
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "curve",
|
||||
"localized_name": "curve",
|
||||
"name": "curve",
|
||||
"type": "CURVE",
|
||||
"widget": {
|
||||
"name": "curve"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "histogram",
|
||||
"localized_name": "histogram",
|
||||
"name": "histogram",
|
||||
"type": "HISTOGRAM",
|
||||
"shape": 7,
|
||||
"link": 36
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "CURVE",
|
||||
"name": "CURVE",
|
||||
"type": "CURVE",
|
||||
"links": [
|
||||
31
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Red",
|
||||
"properties": {
|
||||
"Node name for S&R": "CurveEditor"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"type": "CurveEditor",
|
||||
"pos": [
|
||||
3060,
|
||||
-4000
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
200
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "curve",
|
||||
"localized_name": "curve",
|
||||
"name": "curve",
|
||||
"type": "CURVE",
|
||||
"widget": {
|
||||
"name": "curve"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "histogram",
|
||||
"localized_name": "histogram",
|
||||
"name": "histogram",
|
||||
"type": "HISTOGRAM",
|
||||
"shape": 7,
|
||||
"link": 37
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "CURVE",
|
||||
"name": "CURVE",
|
||||
"type": "CURVE",
|
||||
"links": [
|
||||
32
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Green",
|
||||
"properties": {
|
||||
"Node name for S&R": "CurveEditor"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "CurveEditor",
|
||||
"pos": [
|
||||
3060,
|
||||
-3750
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
200
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "curve",
|
||||
"localized_name": "curve",
|
||||
"name": "curve",
|
||||
"type": "CURVE",
|
||||
"widget": {
|
||||
"name": "curve"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "histogram",
|
||||
"localized_name": "histogram",
|
||||
"name": "histogram",
|
||||
"type": "HISTOGRAM",
|
||||
"shape": 7,
|
||||
"link": 38
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "CURVE",
|
||||
"name": "CURVE",
|
||||
"type": "CURVE",
|
||||
"links": [
|
||||
33
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Blue",
|
||||
"properties": {
|
||||
"Node name for S&R": "CurveEditor"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "GLSLShader",
|
||||
"pos": [
|
||||
3590,
|
||||
-4500
|
||||
],
|
||||
"size": [
|
||||
420,
|
||||
500
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image0",
|
||||
"localized_name": "images.image0",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"link": 29
|
||||
},
|
||||
{
|
||||
"label": "image1",
|
||||
"localized_name": "images.image1",
|
||||
"name": "images.image1",
|
||||
"shape": 7,
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "u_curve0",
|
||||
"localized_name": "curves.u_curve0",
|
||||
"name": "curves.u_curve0",
|
||||
"shape": 7,
|
||||
"type": "CURVE",
|
||||
"link": 30
|
||||
},
|
||||
{
|
||||
"label": "u_curve1",
|
||||
"localized_name": "curves.u_curve1",
|
||||
"name": "curves.u_curve1",
|
||||
"shape": 7,
|
||||
"type": "CURVE",
|
||||
"link": 31
|
||||
},
|
||||
{
|
||||
"label": "u_curve2",
|
||||
"localized_name": "curves.u_curve2",
|
||||
"name": "curves.u_curve2",
|
||||
"shape": 7,
|
||||
"type": "CURVE",
|
||||
"link": 32
|
||||
},
|
||||
{
|
||||
"label": "u_curve3",
|
||||
"localized_name": "curves.u_curve3",
|
||||
"name": "curves.u_curve3",
|
||||
"shape": 7,
|
||||
"type": "CURVE",
|
||||
"link": 33
|
||||
},
|
||||
{
|
||||
"localized_name": "fragment_shader",
|
||||
"name": "fragment_shader",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "fragment_shader"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "size_mode",
|
||||
"name": "size_mode",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "size_mode"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE0",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
28
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "IMAGE1",
|
||||
"name": "IMAGE1",
|
||||
"type": "IMAGE",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "IMAGE2",
|
||||
"name": "IMAGE2",
|
||||
"type": "IMAGE",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "IMAGE3",
|
||||
"name": "IMAGE3",
|
||||
"type": "IMAGE",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform sampler2D u_curve0; // RGB master curve (256x1 LUT)\nuniform sampler2D u_curve1; // Red channel curve\nuniform sampler2D u_curve2; // Green channel curve\nuniform sampler2D u_curve3; // Blue channel curve\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\n// GIMP-compatible curve lookup with manual linear interpolation.\n// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:\n// index = value * (n_samples - 1)\n// f = fract(index)\n// result = (1-f) * samples[floor] + f * samples[ceil]\n//\n// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues\n// that occur with texture() + GL_LINEAR on small 256x1 LUTs.\nfloat applyCurve(sampler2D curve, float value) {\n value = clamp(value, 0.0, 1.0);\n\n float pos = value * 255.0;\n int lo = int(floor(pos));\n int hi = min(lo + 1, 255);\n float f = pos - float(lo);\n\n float a = texelFetch(curve, ivec2(lo, 0), 0).r;\n float b = texelFetch(curve, ivec2(hi, 0), 0).r;\n\n return a + f * (b - a);\n}\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n\n // GIMP order: per-channel curves first, then RGB master curve.\n // See gimp_curve_map_pixels() default case in gimpcurve-map.c:\n // dest = colors_curve( channel_curve( src ) )\n float tmp_r = applyCurve(u_curve1, color.r);\n float tmp_g = applyCurve(u_curve2, color.g);\n float tmp_b = applyCurve(u_curve3, color.b);\n color.r = applyCurve(u_curve0, tmp_r);\n color.g = applyCurve(u_curve0, tmp_g);\n color.b = applyCurve(u_curve0, tmp_b);\n\n fragColor0 = vec4(color.rgb, color.a);\n}\n",
|
||||
"from_input"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"type": "ImageHistogram",
|
||||
"pos": [
|
||||
2800,
|
||||
-4300
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
150
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image",
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 34
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "HISTOGRAM",
|
||||
"name": "rgb",
|
||||
"type": "HISTOGRAM",
|
||||
"links": [
|
||||
35
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "HISTOGRAM",
|
||||
"name": "luminance",
|
||||
"type": "HISTOGRAM",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"localized_name": "HISTOGRAM",
|
||||
"name": "red",
|
||||
"type": "HISTOGRAM",
|
||||
"links": [
|
||||
36
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "HISTOGRAM",
|
||||
"name": "green",
|
||||
"type": "HISTOGRAM",
|
||||
"links": [
|
||||
37
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "HISTOGRAM",
|
||||
"name": "blue",
|
||||
"type": "HISTOGRAM",
|
||||
"links": [
|
||||
38
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "ImageHistogram"
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 29,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 28,
|
||||
"origin_id": 8,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"origin_id": 4,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 2,
|
||||
"type": "CURVE"
|
||||
},
|
||||
{
|
||||
"id": 31,
|
||||
"origin_id": 5,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 3,
|
||||
"type": "CURVE"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"origin_id": 6,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 4,
|
||||
"type": "CURVE"
|
||||
},
|
||||
{
|
||||
"id": 33,
|
||||
"origin_id": 7,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 5,
|
||||
"type": "CURVE"
|
||||
},
|
||||
{
|
||||
"id": 34,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 9,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 35,
|
||||
"origin_id": 9,
|
||||
"origin_slot": 0,
|
||||
"target_id": 4,
|
||||
"target_slot": 1,
|
||||
"type": "HISTOGRAM"
|
||||
},
|
||||
{
|
||||
"id": 36,
|
||||
"origin_id": 9,
|
||||
"origin_slot": 2,
|
||||
"target_id": 5,
|
||||
"target_slot": 1,
|
||||
"type": "HISTOGRAM"
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"origin_id": 9,
|
||||
"origin_slot": 3,
|
||||
"target_id": 6,
|
||||
"target_slot": 1,
|
||||
"type": "HISTOGRAM"
|
||||
},
|
||||
{
|
||||
"id": 38,
|
||||
"origin_id": 9,
|
||||
"origin_slot": 4,
|
||||
"target_id": 7,
|
||||
"target_slot": 1,
|
||||
"type": "HISTOGRAM"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1 +1,322 @@
|
||||
{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}}
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 29,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 29,
|
||||
"type": "4c9d6ea4-b912-40e5-8766-6793a9758c53",
|
||||
"pos": [
|
||||
1970,
|
||||
-230
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
86
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image",
|
||||
"localized_name": "images.image0",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"label": "R",
|
||||
"localized_name": "IMAGE0",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"label": "G",
|
||||
"localized_name": "IMAGE1",
|
||||
"name": "IMAGE1",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"label": "B",
|
||||
"localized_name": "IMAGE2",
|
||||
"name": "IMAGE2",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"label": "A",
|
||||
"localized_name": "IMAGE3",
|
||||
"name": "IMAGE3",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Image Channels",
|
||||
"properties": {
|
||||
"proxyWidgets": []
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 28,
|
||||
"lastLinkId": 39,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image Channels",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
1820,
|
||||
-185,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
2460,
|
||||
-215,
|
||||
120,
|
||||
120
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
39
|
||||
],
|
||||
"localized_name": "images.image0",
|
||||
"label": "image",
|
||||
"pos": [
|
||||
1920,
|
||||
-165
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
26
|
||||
],
|
||||
"localized_name": "IMAGE0",
|
||||
"label": "R",
|
||||
"pos": [
|
||||
2480,
|
||||
-195
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "fb44a77e-0522-43e9-9527-82e7465b3596",
|
||||
"name": "IMAGE1",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
27
|
||||
],
|
||||
"localized_name": "IMAGE1",
|
||||
"label": "G",
|
||||
"pos": [
|
||||
2480,
|
||||
-175
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "81460ee6-0131-402a-874f-6bf3001fc4ff",
|
||||
"name": "IMAGE2",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
28
|
||||
],
|
||||
"localized_name": "IMAGE2",
|
||||
"label": "B",
|
||||
"pos": [
|
||||
2480,
|
||||
-155
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "ae690246-80d4-4951-b1d9-9306d8a77417",
|
||||
"name": "IMAGE3",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
29
|
||||
],
|
||||
"localized_name": "IMAGE3",
|
||||
"label": "A",
|
||||
"pos": [
|
||||
2480,
|
||||
-135
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 23,
|
||||
"type": "GLSLShader",
|
||||
"pos": [
|
||||
2000,
|
||||
-330
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
172
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image",
|
||||
"localized_name": "images.image0",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"link": 39
|
||||
},
|
||||
{
|
||||
"localized_name": "fragment_shader",
|
||||
"name": "fragment_shader",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "fragment_shader"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "size_mode",
|
||||
"name": "size_mode",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "size_mode"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "image1",
|
||||
"localized_name": "images.image1",
|
||||
"name": "images.image1",
|
||||
"shape": 7,
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"label": "R",
|
||||
"localized_name": "IMAGE0",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
26
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "G",
|
||||
"localized_name": "IMAGE1",
|
||||
"name": "IMAGE1",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
27
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "B",
|
||||
"localized_name": "IMAGE2",
|
||||
"name": "IMAGE2",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
28
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "A",
|
||||
"localized_name": "IMAGE3",
|
||||
"name": "IMAGE3",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
29
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n",
|
||||
"from_input"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 39,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 23,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"origin_id": 23,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"origin_id": 23,
|
||||
"origin_slot": 1,
|
||||
"target_id": -20,
|
||||
"target_slot": 1,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 28,
|
||||
"origin_id": 23,
|
||||
"origin_slot": 2,
|
||||
"target_id": -20,
|
||||
"target_slot": 2,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 29,
|
||||
"origin_id": 23,
|
||||
"origin_slot": 3,
|
||||
"target_id": -20,
|
||||
"target_slot": 3,
|
||||
"type": "IMAGE"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1 +1,278 @@
|
||||
{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}}
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 15,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 15,
|
||||
"type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471",
|
||||
"pos": [
|
||||
-1490,
|
||||
2040
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
260
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "prompt",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "prompt"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "reference images",
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"title": "Prompt Enhance",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"-1",
|
||||
"prompt"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.14.1"
|
||||
},
|
||||
"widgets_values": [
|
||||
""
|
||||
]
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 15,
|
||||
"lastLinkId": 14,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Prompt Enhance",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-2170,
|
||||
2110,
|
||||
138.876953125,
|
||||
80
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
-640,
|
||||
2110,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "aeab7216-00e0-4528-a09b-bba50845c5a6",
|
||||
"name": "prompt",
|
||||
"type": "STRING",
|
||||
"linkIds": [
|
||||
11
|
||||
],
|
||||
"pos": [
|
||||
-2051.123046875,
|
||||
2130
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "7b73fd36-aa31-4771-9066-f6c83879994b",
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
14
|
||||
],
|
||||
"label": "reference images",
|
||||
"pos": [
|
||||
-2051.123046875,
|
||||
2150
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "c7b0d930-68a1-48d1-b496-0519e5837064",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"linkIds": [
|
||||
13
|
||||
],
|
||||
"pos": [
|
||||
-620,
|
||||
2130
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 11,
|
||||
"type": "GeminiNode",
|
||||
"pos": [
|
||||
-1560,
|
||||
1990
|
||||
],
|
||||
"size": [
|
||||
470,
|
||||
470
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "images",
|
||||
"name": "images",
|
||||
"shape": 7,
|
||||
"type": "IMAGE",
|
||||
"link": 14
|
||||
},
|
||||
{
|
||||
"localized_name": "audio",
|
||||
"name": "audio",
|
||||
"shape": 7,
|
||||
"type": "AUDIO",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "video",
|
||||
"name": "video",
|
||||
"shape": 7,
|
||||
"type": "VIDEO",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "files",
|
||||
"name": "files",
|
||||
"shape": 7,
|
||||
"type": "GEMINI_INPUT_FILES",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "prompt",
|
||||
"name": "prompt",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "prompt"
|
||||
},
|
||||
"link": 11
|
||||
},
|
||||
{
|
||||
"localized_name": "model",
|
||||
"name": "model",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "seed",
|
||||
"name": "seed",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "seed"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "system_prompt",
|
||||
"name": "system_prompt",
|
||||
"shape": 7,
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "system_prompt"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
13
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.14.1",
|
||||
"Node name for S&R": "GeminiNode"
|
||||
},
|
||||
"widgets_values": [
|
||||
"",
|
||||
"gemini-3-pro-preview",
|
||||
42,
|
||||
"randomize",
|
||||
"You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"
|
||||
],
|
||||
"color": "#432",
|
||||
"bgcolor": "#653"
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 11,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 11,
|
||||
"target_slot": 4,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 11,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Prompt enhance"
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
|
||||
@@ -1 +1,309 @@
|
||||
{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}}
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 25,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 25,
|
||||
"type": "621ba4e2-22a8-482d-a369-023753198b7b",
|
||||
"pos": [
|
||||
4610,
|
||||
-790
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
58
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image",
|
||||
"localized_name": "images.image0",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"label": "IMAGE",
|
||||
"localized_name": "IMAGE0",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Sharpen",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"24",
|
||||
"value"
|
||||
]
|
||||
]
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "621ba4e2-22a8-482d-a369-023753198b7b",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 24,
|
||||
"lastLinkId": 36,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Sharpen",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
4090,
|
||||
-825,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
5150,
|
||||
-825,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
34
|
||||
],
|
||||
"localized_name": "images.image0",
|
||||
"label": "image",
|
||||
"pos": [
|
||||
4190,
|
||||
-805
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
35
|
||||
],
|
||||
"localized_name": "IMAGE0",
|
||||
"label": "IMAGE",
|
||||
"pos": [
|
||||
5170,
|
||||
-805
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 24,
|
||||
"type": "PrimitiveFloat",
|
||||
"pos": [
|
||||
4280,
|
||||
-1240
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
58
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "strength",
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FLOAT",
|
||||
"name": "FLOAT",
|
||||
"type": "FLOAT",
|
||||
"links": [
|
||||
36
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "PrimitiveFloat",
|
||||
"min": 0,
|
||||
"max": 3,
|
||||
"precision": 2,
|
||||
"step": 0.05
|
||||
},
|
||||
"widgets_values": [
|
||||
0.5
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"type": "GLSLShader",
|
||||
"pos": [
|
||||
4570,
|
||||
-1240
|
||||
],
|
||||
"size": [
|
||||
370,
|
||||
192
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "image0",
|
||||
"localized_name": "images.image0",
|
||||
"name": "images.image0",
|
||||
"type": "IMAGE",
|
||||
"link": 34
|
||||
},
|
||||
{
|
||||
"label": "image1",
|
||||
"localized_name": "images.image1",
|
||||
"name": "images.image1",
|
||||
"shape": 7,
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "u_float0",
|
||||
"localized_name": "floats.u_float0",
|
||||
"name": "floats.u_float0",
|
||||
"shape": 7,
|
||||
"type": "FLOAT",
|
||||
"link": 36
|
||||
},
|
||||
{
|
||||
"label": "u_float1",
|
||||
"localized_name": "floats.u_float1",
|
||||
"name": "floats.u_float1",
|
||||
"shape": 7,
|
||||
"type": "FLOAT",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "u_int0",
|
||||
"localized_name": "ints.u_int0",
|
||||
"name": "ints.u_int0",
|
||||
"shape": 7,
|
||||
"type": "INT",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "fragment_shader",
|
||||
"name": "fragment_shader",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "fragment_shader"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "size_mode",
|
||||
"name": "size_mode",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "size_mode"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE0",
|
||||
"name": "IMAGE0",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
35
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "IMAGE1",
|
||||
"name": "IMAGE1",
|
||||
"type": "IMAGE",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "IMAGE2",
|
||||
"name": "IMAGE2",
|
||||
"type": "IMAGE",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "IMAGE3",
|
||||
"name": "IMAGE3",
|
||||
"type": "IMAGE",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GLSLShader"
|
||||
},
|
||||
"widgets_values": [
|
||||
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
|
||||
"from_input"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 36,
|
||||
"origin_id": 24,
|
||||
"origin_slot": 0,
|
||||
"target_id": 23,
|
||||
"target_slot": 2,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 34,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 23,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 35,
|
||||
"origin_id": 23,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1 +1,420 @@
|
||||
{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}}
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 13,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 13,
|
||||
"type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1",
|
||||
"pos": [
|
||||
1120,
|
||||
330
|
||||
],
|
||||
"size": [
|
||||
240,
|
||||
58
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "video",
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "VIDEO",
|
||||
"name": "VIDEO",
|
||||
"type": "VIDEO",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Video Upscale(GAN x4)",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"-1",
|
||||
"model_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.14.1"
|
||||
},
|
||||
"widgets_values": [
|
||||
"RealESRGAN_x4plus.safetensors"
|
||||
]
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 13,
|
||||
"lastLinkId": 19,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Video Upscale(GAN x4)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
550,
|
||||
460,
|
||||
120,
|
||||
80
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
1490,
|
||||
460,
|
||||
120,
|
||||
60
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6",
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"linkIds": [
|
||||
10
|
||||
],
|
||||
"localized_name": "video",
|
||||
"pos": [
|
||||
650,
|
||||
480
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "2e23a087-caa8-4d65-99e6-662761aa905a",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
19
|
||||
],
|
||||
"pos": [
|
||||
650,
|
||||
500
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70",
|
||||
"name": "VIDEO",
|
||||
"type": "VIDEO",
|
||||
"linkIds": [
|
||||
15
|
||||
],
|
||||
"localized_name": "VIDEO",
|
||||
"pos": [
|
||||
1510,
|
||||
480
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 2,
|
||||
"type": "ImageUpscaleWithModel",
|
||||
"pos": [
|
||||
1110,
|
||||
450
|
||||
],
|
||||
"size": [
|
||||
320,
|
||||
46
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "upscale_model",
|
||||
"name": "upscale_model",
|
||||
"type": "UPSCALE_MODEL",
|
||||
"link": 1
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 14
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
13
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.10.0",
|
||||
"Node name for S&R": "ImageUpscaleWithModel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "CreateVideo",
|
||||
"pos": [
|
||||
1110,
|
||||
550
|
||||
],
|
||||
"size": [
|
||||
320,
|
||||
78
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "images",
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": 13
|
||||
},
|
||||
{
|
||||
"localized_name": "audio",
|
||||
"name": "audio",
|
||||
"shape": 7,
|
||||
"type": "AUDIO",
|
||||
"link": 16
|
||||
},
|
||||
{
|
||||
"localized_name": "fps",
|
||||
"name": "fps",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "fps"
|
||||
},
|
||||
"link": 12
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "VIDEO",
|
||||
"name": "VIDEO",
|
||||
"type": "VIDEO",
|
||||
"links": [
|
||||
15
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.10.0",
|
||||
"Node name for S&R": "CreateVideo"
|
||||
},
|
||||
"widgets_values": [
|
||||
30
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"type": "GetVideoComponents",
|
||||
"pos": [
|
||||
1110,
|
||||
330
|
||||
],
|
||||
"size": [
|
||||
320,
|
||||
70
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "video",
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"link": 10
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "images",
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
14
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "audio",
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"links": [
|
||||
16
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "fps",
|
||||
"name": "fps",
|
||||
"type": "FLOAT",
|
||||
"links": [
|
||||
12
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.10.0",
|
||||
"Node name for S&R": "GetVideoComponents"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "UpscaleModelLoader",
|
||||
"pos": [
|
||||
750,
|
||||
450
|
||||
],
|
||||
"size": [
|
||||
280,
|
||||
60
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model_name",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": 19
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "UPSCALE_MODEL",
|
||||
"name": "UPSCALE_MODEL",
|
||||
"type": "UPSCALE_MODEL",
|
||||
"links": [
|
||||
1
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.10.0",
|
||||
"Node name for S&R": "UpscaleModelLoader",
|
||||
"models": [
|
||||
{
|
||||
"name": "RealESRGAN_x4plus.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors",
|
||||
"directory": "upscale_models"
|
||||
}
|
||||
]
|
||||
},
|
||||
"widgets_values": [
|
||||
"RealESRGAN_x4plus.safetensors"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 1,
|
||||
"origin_id": 1,
|
||||
"origin_slot": 0,
|
||||
"target_id": 2,
|
||||
"target_slot": 0,
|
||||
"type": "UPSCALE_MODEL"
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
"origin_id": 10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 2,
|
||||
"target_slot": 1,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"origin_id": 2,
|
||||
"origin_slot": 0,
|
||||
"target_id": 11,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 16,
|
||||
"origin_id": 10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 11,
|
||||
"target_slot": 1,
|
||||
"type": "AUDIO"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"origin_id": 10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 11,
|
||||
"target_slot": 2,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 10,
|
||||
"target_slot": 0,
|
||||
"type": "VIDEO"
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "VIDEO"
|
||||
},
|
||||
{
|
||||
"id": 19,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 1,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Enhance video"
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
|
||||
@@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
CACHE_RAM_AUTO_GB = -1.0
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
|
||||
@@ -93,6 +93,50 @@ class IndexListCallbacks:
|
||||
return {}
|
||||
|
||||
|
||||
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
|
||||
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
|
||||
return None
|
||||
cond_tensor = cond_value.cond
|
||||
if temporal_dim >= cond_tensor.ndim:
|
||||
return None
|
||||
|
||||
cond_size = cond_tensor.size(temporal_dim)
|
||||
|
||||
if temporal_scale == 1:
|
||||
expected_size = x_in.size(window.dim) - temporal_offset
|
||||
if cond_size != expected_size:
|
||||
return None
|
||||
|
||||
if temporal_offset == 0 and temporal_scale == 1:
|
||||
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
if temporal_scale > 1:
|
||||
scaled = []
|
||||
for i in indices:
|
||||
for k in range(temporal_scale):
|
||||
si = i * temporal_scale + k
|
||||
if si < cond_size:
|
||||
scaled.append(si)
|
||||
indices = scaled
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
idx = tuple([slice(None)] * temporal_dim + [indices])
|
||||
sliced = cond_tensor[idx].to(device)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if not handled and self._model is not None:
|
||||
result = self._model.resize_cond_for_context_window(
|
||||
cond_key, cond_value, window, x_in, device,
|
||||
retain_index_list=self.cond_retain_index_list)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
@@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
@@ -611,6 +611,7 @@ class AceStepDiTModel(nn.Module):
|
||||
intermediate_size,
|
||||
patch_size,
|
||||
audio_acoustic_hidden_dim,
|
||||
condition_dim=None,
|
||||
layer_types=None,
|
||||
sliding_window=128,
|
||||
rms_norm_eps=1e-6,
|
||||
@@ -640,7 +641,7 @@ class AceStepDiTModel(nn.Module):
|
||||
|
||||
self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||
self.condition_embedder = Linear(condition_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
if layer_types is None:
|
||||
layer_types = ["full_attention"] * num_layers
|
||||
@@ -1035,6 +1036,9 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
fsq_dim=2048,
|
||||
fsq_levels=[8, 8, 8, 5, 5, 5],
|
||||
fsq_input_num_quantizers=1,
|
||||
encoder_hidden_size=2048,
|
||||
encoder_intermediate_size=6144,
|
||||
encoder_num_heads=16,
|
||||
audio_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
@@ -1054,24 +1058,24 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
|
||||
self.decoder = AceStepDiTModel(
|
||||
in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
|
||||
intermediate_size, patch_size, audio_acoustic_hidden_dim,
|
||||
intermediate_size, patch_size, audio_acoustic_hidden_dim, condition_dim=encoder_hidden_size,
|
||||
layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.encoder = AceStepConditionEncoder(
|
||||
text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
|
||||
num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
|
||||
text_hidden_dim, timbre_hidden_dim, encoder_hidden_size, num_lyric_layers, num_timbre_layers,
|
||||
encoder_num_heads, num_kv_heads, head_dim, encoder_intermediate_size, rms_norm_eps,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.tokenizer = AceStepAudioTokenizer(
|
||||
audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
|
||||
audio_acoustic_hidden_dim, encoder_hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.detokenizer = AudioTokenDetokenizer(
|
||||
hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
|
||||
encoder_hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
|
||||
self.null_condition_emb = nn.Parameter(torch.empty(1, 1, encoder_hidden_size, dtype=dtype, device=device))
|
||||
|
||||
def prepare_condition(
|
||||
self,
|
||||
|
||||
@@ -136,16 +136,7 @@ class ResBlock(nn.Module):
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
|
||||
# Init weights
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||
|
||||
def _norm(self, x, norm):
|
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
@@ -386,7 +386,7 @@ class Flux(nn.Module):
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
|
||||
@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
|
||||
# Inject reference audio for ID-LoRA in-context conditioning
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
ref_audio_seq_len = 0
|
||||
if ref_audio is not None:
|
||||
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||
if ref_tokens.shape[0] < ax.shape[0]:
|
||||
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||
ref_audio_seq_len = ref_tokens.shape[1]
|
||||
B = ax.shape[0]
|
||||
|
||||
# Compute negative temporal positions matching ID-LoRA convention:
|
||||
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||
p = self.a_patchifier
|
||||
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||
time_offset = ref_end[-1].item() + tpl
|
||||
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||
|
||||
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||
target_len = kwargs.get("target_audio_seq_len")
|
||||
if a_timestep.dim() <= 1:
|
||||
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Trim reference audio tokens before unpatchification
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0:
|
||||
ax = ax[:, ref_audio_seq_len:]
|
||||
if a_embedded_timestep.shape[1] > 1:
|
||||
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if isinstance(stride, int):
|
||||
self.time_stride = stride
|
||||
else:
|
||||
self.time_stride = stride[0]
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
@@ -58,16 +63,25 @@ class CausalConv3d(nn.Module):
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
input_length = sum([piece.shape[2] for piece in pieces])
|
||||
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
if needs_caching and cache_length == 0:
|
||||
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
if needs_caching and x.shape[2] >= cache_length:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
|
||||
@@ -233,10 +233,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
@@ -282,9 +283,35 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||
|
||||
outputs = []
|
||||
samples = [sample[:, :, :1, :, :]]
|
||||
if sample.shape[2] > 1:
|
||||
chunk_t = max(2, max_chunk_size // frame_size)
|
||||
if chunk_t < 4:
|
||||
chunk_t = 2
|
||||
elif chunk_t < 8:
|
||||
chunk_t = 4
|
||||
else:
|
||||
chunk_t = (chunk_t // 8) * 8
|
||||
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||
for chunk_idx, chunk in enumerate(samples):
|
||||
if chunk_idx == len(samples) - 1:
|
||||
mark_conv3d_ended(self)
|
||||
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||
output = self._forward_chunk(chunk)
|
||||
if output is not None:
|
||||
outputs.append(output)
|
||||
|
||||
return torch_cat_if_needed(outputs, dim=2)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
@@ -297,7 +324,23 @@ class Encoder(nn.Module):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||
|
||||
def get_max_chunk_size(device: torch.device) -> int:
|
||||
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||
|
||||
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||
return MIN_CHUNK_SIZE
|
||||
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||
return MAX_CHUNK_SIZE
|
||||
|
||||
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||
)
|
||||
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -457,6 +500,17 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||
ts, hs, ws, to = 1, 1, 1, 0
|
||||
for block in self.up_blocks:
|
||||
if isinstance(block, DepthToSpaceUpsample):
|
||||
ts *= block.stride[0]
|
||||
hs *= block.stride[1]
|
||||
ws *= block.stride[2]
|
||||
if block.stride[0] > 1:
|
||||
to = to * block.stride[0] + 1
|
||||
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
@@ -478,11 +532,62 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
output_offset[0] += t
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if ended:
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
@@ -497,6 +602,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
timestep_shift_scale = None
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
@@ -524,48 +630,18 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
output = []
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
return
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
return output_buffer
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -689,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.stride[0] == 2:
|
||||
tid = threading.get_ident()
|
||||
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||
if cached_input is not None:
|
||||
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||
cached_input = None
|
||||
|
||||
if self.stride[0] == 2 and pad_first:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
pad_first = False
|
||||
|
||||
if x.shape[2] < self.stride[0]:
|
||||
cached_input = x
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
return None
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
@@ -709,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||
if cached_x is not None:
|
||||
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||
cached_x = None
|
||||
else:
|
||||
cached_x = x
|
||||
x = None
|
||||
|
||||
x = x + x_in
|
||||
if x is not None:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
|
||||
return x
|
||||
|
||||
@@ -1050,6 +1150,8 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
comfy_has_chunked_io = True
|
||||
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -1192,14 +1294,15 @@ class VideoVAE(nn.Module):
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
def encode(self, x, device=None):
|
||||
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||
|
||||
@@ -155,6 +155,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
def __init__(self, embed_dim: int, **kwargs):
|
||||
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
||||
ddconfig = kwargs.pop("ddconfig")
|
||||
decoder_ddconfig = kwargs.pop("decoder_ddconfig", ddconfig)
|
||||
super().__init__(
|
||||
encoder_config={
|
||||
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
|
||||
@@ -162,7 +163,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
},
|
||||
decoder_config={
|
||||
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
|
||||
"params": ddconfig,
|
||||
"params": decoder_ddconfig,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -3,12 +3,9 @@ from ..diffusionmodules.openaimodel import Timestep
|
||||
import torch
|
||||
|
||||
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
|
||||
def __init__(self, *args, timestep_dim=256, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if clip_stats_path is None:
|
||||
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
|
||||
else:
|
||||
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
|
||||
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
|
||||
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
|
||||
self.register_buffer("data_std", clip_std[None, :], persistent=False)
|
||||
self.time_embed = Timestep(timestep_dim)
|
||||
|
||||
725
comfy/ldm/rt_detr/rtdetr_v4.py
Normal file
725
comfy/ldm/rt_detr/rtdetr_v4.py
Normal file
@@ -0,0 +1,725 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
COCO_CLASSES = [
|
||||
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
|
||||
'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
|
||||
'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
|
||||
'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
|
||||
'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
|
||||
'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
|
||||
'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
|
||||
'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
|
||||
'remote','keyboard','cell phone','microwave','oven','toaster','sink',
|
||||
'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HGNetv2 backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvBNAct(nn.Module):
|
||||
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
|
||||
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
||||
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.act = nn.ReLU() if use_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
class LightConvBNAct(nn.Module):
|
||||
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
|
||||
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.conv1(x))
|
||||
|
||||
class _StemBlock(nn.Module):
|
||||
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
|
||||
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
|
||||
# padding is applied manually in forward (matching PaddlePaddle original)
|
||||
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem1(x)
|
||||
x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a
|
||||
x2 = self.stem2a(x)
|
||||
x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b
|
||||
x2 = self.stem2b(x2)
|
||||
x1 = self.pool(x)
|
||||
return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
|
||||
|
||||
|
||||
class _HG_Block(nn.Module):
|
||||
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.residual = residual
|
||||
if light:
|
||||
self.layers = nn.ModuleList(
|
||||
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
|
||||
total = ic + layer_num * mc
|
||||
|
||||
self.aggregation = nn.Sequential(
|
||||
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
|
||||
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
outs = [x]
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outs.append(x)
|
||||
x = self.aggregation(torch.cat(outs, 1))
|
||||
return x + identity if self.residual else x
|
||||
|
||||
|
||||
class _HG_Stage(nn.Module):
|
||||
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
|
||||
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
if downsample:
|
||||
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
self.blocks = nn.Sequential(*[
|
||||
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
|
||||
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
|
||||
for i in range(num_blocks)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
return self.blocks(self.downsample(x))
|
||||
|
||||
|
||||
class HGNetv2(nn.Module):
|
||||
# B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
|
||||
_STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
|
||||
[128, 128, 512, 2, True, False, 3, 6],
|
||||
[512, 256, 1024, 5, True, True, 5, 6],
|
||||
[1024,512, 2048, 2, True, True, 5, 6]]
|
||||
|
||||
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
|
||||
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
|
||||
self.return_idx = list(return_idx)
|
||||
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.stem(x)
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
if i in self.return_idx:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvNormLayer(nn.Module):
|
||||
"""Conv→act (expects pre-fused BN weights)."""
|
||||
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
p = (k - 1) // 2 if padding is None else padding
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
|
||||
class VGGBlock(nn.Module):
|
||||
"""Rep-VGG block (expects pre-fused weights)."""
|
||||
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
|
||||
class CSPLayer(nn.Module):
|
||||
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
h = int(oc * expansion)
|
||||
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
|
||||
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
|
||||
|
||||
|
||||
class RepNCSPELAN4(nn.Module):
|
||||
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
|
||||
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.c = c3 // 2
|
||||
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
|
||||
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
|
||||
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
y = list(self.cv1(x).split((self.c, self.c), 1))
|
||||
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
||||
return self.cv4(torch.cat(y, 1))
|
||||
|
||||
|
||||
class SCDown(nn.Module):
|
||||
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
|
||||
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
return self.cv2(self.cv1(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query, key, value, attn_mask=None):
|
||||
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
|
||||
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
|
||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class _TransformerEncoderLayer(nn.Module):
|
||||
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
|
||||
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, src, src_mask=None, pos_embed=None):
|
||||
q = k = src if pos_embed is None else src + pos_embed
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
|
||||
src = self.norm1(src + src2)
|
||||
src2 = self.linear2(self.activation(self.linear1(src)))
|
||||
return self.norm2(src + src2)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
|
||||
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, src, src_mask=None, pos_embed=None):
|
||||
for layer in self.layers:
|
||||
src = layer(src, src_mask=src_mask, pos_embed=pos_embed)
|
||||
return src
|
||||
|
||||
|
||||
class HybridEncoder(nn.Module):
|
||||
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
|
||||
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = list(in_channels)
|
||||
self.feat_strides = list(feat_strides)
|
||||
self.hidden_dim = hidden_dim
|
||||
self.use_encoder_idx = list(use_encoder_idx)
|
||||
self.pe_temperature = pe_temperature
|
||||
self.eval_spatial_size = eval_spatial_size
|
||||
self.out_channels = [hidden_dim] * len(in_channels)
|
||||
self.out_strides = list(feat_strides)
|
||||
|
||||
# channel projection (expects pre-fused weights)
|
||||
self.input_proj = nn.ModuleList([
|
||||
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
|
||||
for ch in in_channels
|
||||
])
|
||||
|
||||
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
|
||||
self.encoder = nn.ModuleList([
|
||||
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(use_encoder_idx))
|
||||
])
|
||||
|
||||
nb = round(3 * depth_mult)
|
||||
exp = expansion
|
||||
|
||||
# top-down FPN (dfine: lateral conv has no act)
|
||||
self.lateral_convs = nn.ModuleList(
|
||||
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
self.fpn_blocks = nn.ModuleList(
|
||||
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
|
||||
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
|
||||
self.downsample_convs = nn.ModuleList(
|
||||
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
self.pan_blocks = nn.ModuleList(
|
||||
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
|
||||
# cache positional embeddings for fixed spatial size
|
||||
if eval_spatial_size:
|
||||
for idx in self.use_encoder_idx:
|
||||
stride = self.feat_strides[idx]
|
||||
pe = self._build_pe(eval_spatial_size[1] // stride,
|
||||
eval_spatial_size[0] // stride,
|
||||
hidden_dim, pe_temperature)
|
||||
setattr(self, f'pos_embed{idx}', pe)
|
||||
|
||||
@staticmethod
|
||||
def _build_pe(w, h, dim=256, temp=10000.):
|
||||
assert dim % 4 == 0
|
||||
gw = torch.arange(w, dtype=torch.float32)
|
||||
gh = torch.arange(h, dtype=torch.float32)
|
||||
gw, gh = torch.meshgrid(gw, gh, indexing='ij')
|
||||
pdim = dim // 4
|
||||
omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
|
||||
ow = gw.flatten()[:, None] @ omega[None]
|
||||
oh = gh.flatten()[:, None] @ omega[None]
|
||||
return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]
|
||||
|
||||
def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
|
||||
|
||||
for i, enc_idx in enumerate(self.use_encoder_idx):
|
||||
h, w = proj[enc_idx].shape[2:]
|
||||
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
|
||||
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
|
||||
for layer in self.encoder[i].layers:
|
||||
src = layer(src, pos_embed=pe)
|
||||
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
|
||||
|
||||
n = len(self.in_channels)
|
||||
inner = [proj[-1]]
|
||||
for k in range(n - 1, 0, -1):
|
||||
j = n - 1 - k
|
||||
top = self.lateral_convs[j](inner[0])
|
||||
inner[0] = top
|
||||
up = F.interpolate(top, scale_factor=2., mode='nearest')
|
||||
inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))
|
||||
|
||||
outs = [inner[0]]
|
||||
for k in range(n - 1):
|
||||
outs.append(self.pan_blocks[k](
|
||||
torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
|
||||
return outs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder — DFINETransformer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
|
||||
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
|
||||
attention_weights : [bs, Lq, n_head, sum(pts)]
|
||||
"""
|
||||
_, c = value[0].shape[:2] # bs*n_head, c
|
||||
_, Lq, n_head, _, _ = sampling_locations.shape
|
||||
bs = sampling_locations.shape[0]
|
||||
n_h = n_head
|
||||
|
||||
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
|
||||
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
|
||||
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
|
||||
|
||||
sampled = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
val_l = value[lvl].reshape(bs * n_h, c, h, w)
|
||||
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
|
||||
|
||||
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
|
||||
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
|
||||
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
|
||||
out = out.reshape(bs, n_h * c, Lq)
|
||||
return out.permute(0, 2, 1) # [bs, Lq, hidden]
|
||||
|
||||
|
||||
class MSDeformableAttention(nn.Module):
|
||||
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim, self.num_heads = embed_dim, num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
|
||||
self.num_points_list = pts
|
||||
self.offset_scale = offset_scale
|
||||
total = num_heads * sum(pts)
|
||||
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
|
||||
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
|
||||
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query, ref_pts, value, spatial_shapes):
|
||||
bs, Lq = query.shape[:2]
|
||||
offsets = self.sampling_offsets(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list), 2)
|
||||
attn_w = F.softmax(
|
||||
self.attention_weights(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
|
||||
scale = self.num_points_scale.to(query).unsqueeze(-1)
|
||||
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
|
||||
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
|
||||
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
def __init__(self, d_model, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
|
||||
return self.norm(g1 * x1 + g2 * x2)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
|
||||
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
|
||||
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU()
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
|
||||
q = k = target if query_pos is None else target + query_pos
|
||||
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
|
||||
target = self.norm1(target + t2)
|
||||
t2 = self.cross_attn(
|
||||
target if query_pos is None else target + query_pos,
|
||||
ref_pts, value, spatial_shapes)
|
||||
target = self.gateway(target, t2)
|
||||
t2 = self.linear2(self.activation(self.linear1(target)))
|
||||
target = self.norm3((target + t2).clamp(-65504, 65504))
|
||||
return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FDR utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def weighting_function(reg_max, up, reg_scale):
|
||||
"""Non-uniform weighting function W(n) for FDR box regression."""
|
||||
ub1 = (abs(up[0]) * abs(reg_scale)).item()
|
||||
ub2 = ub1 * 2
|
||||
step = (ub1 + 1) ** (2 / (reg_max - 2))
|
||||
left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
|
||||
right = [ (step ** i) - 1 for i in range(1, reg_max // 2)]
|
||||
vals = [-ub2] + left + [0] + right + [ub2]
|
||||
return torch.tensor(vals, dtype=up.dtype, device=up.device)
|
||||
|
||||
|
||||
def distance2bbox(points, distance, reg_scale):
|
||||
"""Decode edge-distances → cxcywh boxes."""
|
||||
rs = abs(reg_scale).to(dtype=points.dtype)
|
||||
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
|
||||
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
|
||||
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
|
||||
y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs)
|
||||
x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
|
||||
return torch.stack([x0, y0, x1_, y1_], -1)
|
||||
|
||||
|
||||
class Integral(nn.Module):
|
||||
"""Sum Pr(n)·W(n) over the distribution bins."""
|
||||
def __init__(self, reg_max=32):
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
|
||||
def forward(self, x, project):
|
||||
shape = x.shape
|
||||
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
|
||||
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
|
||||
return x.reshape(list(shape[:-1]) + [-1])
|
||||
|
||||
|
||||
class LQE(nn.Module):
|
||||
"""Location Quality Estimator — refines class scores using corner distribution."""
|
||||
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.k, self.reg_max = k, reg_max
|
||||
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, scores, pred_corners):
|
||||
B, L, _ = pred_corners.shape
|
||||
prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1)
|
||||
topk, _ = prob.topk(self.k, -1)
|
||||
stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1)
|
||||
return scores + self.reg_conf(stat.reshape(B, L, -1))
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
self.nhead = nhead
|
||||
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.eval_idx + 1)
|
||||
])
|
||||
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
|
||||
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
|
||||
|
||||
def _value_op(self, memory, spatial_shapes):
|
||||
"""Reshape memory to per-level value tensors for deformable attention."""
|
||||
c = self.hidden_dim // self.nhead
|
||||
split = [h * w for h, w in spatial_shapes]
|
||||
val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim]
|
||||
# → [bs, n_head, c, sum_hw]
|
||||
val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw]
|
||||
return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l]
|
||||
|
||||
def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
|
||||
val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention
|
||||
|
||||
# reshape to [bs*n_head, c, h_l, w_l]
|
||||
value = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
v = val_split_flat[lvl] # [bs*n_head, c, h*w]
|
||||
value.append(v.reshape(v.shape[0], v.shape[1], h, w))
|
||||
|
||||
ref_pts = F.sigmoid(ref_pts_unact)
|
||||
output = target
|
||||
output_detach = pred_corners_undetach = 0
|
||||
|
||||
dec_bboxes, dec_logits = [], []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4]
|
||||
query_pos = query_pos_head(ref_pts).clamp(-10, 10)
|
||||
output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)
|
||||
|
||||
if i == 0:
|
||||
ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
|
||||
ref_unact = torch.log(ref_unact / (1 - ref_unact))
|
||||
pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
|
||||
ref_pts_initial = pre_bboxes.detach()
|
||||
|
||||
pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
|
||||
inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)
|
||||
|
||||
if i == self.eval_idx:
|
||||
scores = score_head[i](output)
|
||||
scores = self.lqe_layers[i](scores, pred_corners)
|
||||
dec_bboxes.append(inter_ref_bbox)
|
||||
dec_logits.append(scores)
|
||||
break
|
||||
|
||||
pred_corners_undetach = pred_corners
|
||||
ref_pts = inter_ref_bbox.detach()
|
||||
output_detach = output.detach()
|
||||
|
||||
return torch.stack(dec_bboxes), torch.stack(dec_logits)
|
||||
|
||||
|
||||
class DFINETransformer(nn.Module):
|
||||
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
|
||||
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
|
||||
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert len(feat_strides) == len(feat_channels)
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_queries = num_queries
|
||||
self.num_levels = num_levels
|
||||
self.eps = eps
|
||||
self.eval_spatial_size = eval_spatial_size
|
||||
|
||||
self.feat_strides = list(feat_strides)
|
||||
for i in range(num_levels - len(feat_strides)):
|
||||
self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))
|
||||
|
||||
# input projection (expects pre-fused weights)
|
||||
self.input_proj = nn.ModuleList()
|
||||
for ch in feat_channels:
|
||||
if ch == hidden_dim:
|
||||
self.input_proj.append(nn.Identity())
|
||||
else:
|
||||
self.input_proj.append(nn.Sequential(OrderedDict([
|
||||
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
|
||||
in_ch = feat_channels[-1]
|
||||
for i in range(num_levels - len(feat_channels)):
|
||||
self.input_proj.append(nn.Sequential(OrderedDict([
|
||||
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
|
||||
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
|
||||
in_ch = hidden_dim
|
||||
|
||||
# FDR parameters (non-trainable placeholders, set from config)
|
||||
self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
|
||||
self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
|
||||
|
||||
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
|
||||
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
|
||||
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.enc_output = nn.Sequential(OrderedDict([
|
||||
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
|
||||
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
|
||||
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
|
||||
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
self.dec_score_head = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
|
||||
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.dec_bbox_head = nn.ModuleList(
|
||||
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.eval_idx_ + 1)])
|
||||
self.integral = Integral(reg_max)
|
||||
|
||||
if eval_spatial_size:
|
||||
# Register as buffers so checkpoint values override the freshly-computed defaults
|
||||
anchors, valid_mask = self._gen_anchors()
|
||||
self.register_buffer('anchors', anchors)
|
||||
self.register_buffer('valid_mask', valid_mask)
|
||||
|
||||
def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
|
||||
if spatial_shapes is None:
|
||||
h0, w0 = self.eval_spatial_size
|
||||
spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
|
||||
anchors = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
|
||||
gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
|
||||
wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
|
||||
anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
|
||||
anchors = torch.cat(anchors, 1).to(device)
|
||||
valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
|
||||
return anchors, valid_mask
|
||||
|
||||
def _encoder_input(self, feats: List[torch.Tensor]):
|
||||
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
|
||||
for i in range(len(feats), self.num_levels):
|
||||
proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
|
||||
flat, shapes = [], []
|
||||
for f in proj:
|
||||
_, _, h, w = f.shape
|
||||
flat.append(f.flatten(2).permute(0, 2, 1))
|
||||
shapes.append([h, w])
|
||||
return torch.cat(flat, 1), shapes
|
||||
|
||||
def _decoder_input(self, memory: torch.Tensor):
|
||||
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
|
||||
if memory.shape[0] > 1:
|
||||
anchors = anchors.repeat(memory.shape[0], 1, 1)
|
||||
|
||||
mem = valid_mask.to(memory) * memory
|
||||
out_mem = self.enc_output(mem)
|
||||
logits = self.enc_score_head(out_mem)
|
||||
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
|
||||
idx_e = idx.unsqueeze(-1)
|
||||
topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
|
||||
topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
|
||||
topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
|
||||
return topk_mem.detach(), topk_ref.detach()
|
||||
|
||||
def forward(self, feats: List[torch.Tensor]):
|
||||
memory, shapes = self._encoder_input(feats)
|
||||
content, ref = self._decoder_input(memory)
|
||||
out_bboxes, out_logits = self.decoder(
|
||||
content, ref, memory, shapes,
|
||||
self.dec_bbox_head, self.dec_score_head,
|
||||
self.query_pos_head, self.pre_bbox_head, self.integral)
|
||||
return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RTv4(nn.Module):
|
||||
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.operations = operations
|
||||
|
||||
self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
|
||||
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
|
||||
self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
|
||||
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.num_queries = num_queries
|
||||
self.load_device = comfy.model_management.get_torch_device()
|
||||
|
||||
def _forward(self, x: torch.Tensor):
|
||||
return self.decoder(self.encoder(self.backbone(x)))
|
||||
|
||||
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
|
||||
logits = outputs['pred_logits']
|
||||
boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
|
||||
boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
|
||||
scores = F.sigmoid(logits)
|
||||
scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
|
||||
labels = idx % self.num_classes
|
||||
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
|
||||
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
|
||||
|
||||
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
|
||||
outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
|
||||
return self.postprocess(outputs, orig_size)
|
||||
@@ -99,7 +99,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@@ -109,22 +109,7 @@ class Resample(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
@@ -145,19 +130,24 @@ class Resample(nn.Module):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
feat_cache[idx] = x
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
cache_x = x[:, :, -1:, :, :]
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
deferred_x = feat_cache[idx + 1]
|
||||
if deferred_x is not None:
|
||||
x = torch.cat([deferred_x, x], 2)
|
||||
feat_cache[idx + 1] = None
|
||||
|
||||
if x.shape[2] == 1 and not final:
|
||||
feat_cache[idx + 1] = x
|
||||
x = None
|
||||
|
||||
feat_idx[0] += 2
|
||||
return x
|
||||
|
||||
|
||||
@@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
@@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if x is None:
|
||||
return None
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -389,18 +360,48 @@ class Decoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
|
||||
x = x_ref[0]
|
||||
x_ref[0] = None
|
||||
if layer_idx >= len(self.upsamples):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
return
|
||||
|
||||
layer = self.upsamples[layer_idx]
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
|
||||
for frame_idx in range(0, x.shape[2], 2):
|
||||
self.run_up(
|
||||
layer_idx + 1,
|
||||
[x[:, :, frame_idx:frame_idx + 2, :, :]],
|
||||
feat_cache,
|
||||
feat_idx.copy(),
|
||||
out_chunks,
|
||||
)
|
||||
del x
|
||||
return
|
||||
|
||||
next_x_ref = [x]
|
||||
del x
|
||||
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -409,42 +410,21 @@ class Decoder3d(nn.Module):
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
out_chunks = []
|
||||
|
||||
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||
return out_chunks
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
def count_cache_layers(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -482,11 +462,12 @@ class WanVAE(nn.Module):
|
||||
conv_idx = [0]
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
t = 1 + ((t - 1) // 4) * 4
|
||||
iter_ = 1 + (t - 1) // 2
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
feat_map = [None] * count_cache_layers(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
@@ -496,20 +477,23 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_idx=conv_idx,
|
||||
final=(i == (iter_ - 1)))
|
||||
if out_ is None:
|
||||
continue
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
# z: [b,c,t,h,w]
|
||||
iter_ = z.shape[2]
|
||||
iter_ = 1 + z.shape[2] // 2
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
feat_map = [None] * count_cache_layers(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -520,8 +504,8 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
return out
|
||||
out += out_
|
||||
return torch.cat(out, 2)
|
||||
|
||||
@@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
@@ -138,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
|
||||
return dest_views
|
||||
|
||||
aimdo_enabled = False
|
||||
|
||||
extra_ram_release_callback = None
|
||||
RAM_CACHE_HEADROOM = 0
|
||||
|
||||
def set_ram_cache_release_state(callback, headroom):
|
||||
global extra_ram_release_callback
|
||||
global RAM_CACHE_HEADROOM
|
||||
extra_ram_release_callback = callback
|
||||
RAM_CACHE_HEADROOM = max(0, int(headroom))
|
||||
|
||||
def extra_ram_release(target):
|
||||
if extra_ram_release_callback is None:
|
||||
return 0
|
||||
return extra_ram_release_callback(target)
|
||||
|
||||
@@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@@ -51,6 +52,7 @@ import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@@ -285,6 +287,12 @@ class BaseModel(torch.nn.Module):
|
||||
return data
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||
Return a sliced cond object, or None to fall through to default handling.
|
||||
Use comfy.context_windows.slice_cond() for common cases."""
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@@ -883,7 +891,7 @@ class Flux(BaseModel):
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
return kwargs.get("pooled_output", None)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@@ -930,9 +938,10 @@ class LongCatImage(Flux):
|
||||
transformer_options = transformer_options.copy()
|
||||
rope_opts = transformer_options.get("rope_options", {})
|
||||
rope_opts = dict(rope_opts)
|
||||
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
|
||||
rope_opts.setdefault("shift_t", 1.0)
|
||||
rope_opts.setdefault("shift_y", 512.0)
|
||||
rope_opts.setdefault("shift_x", 512.0)
|
||||
rope_opts.setdefault("shift_y", pe_len)
|
||||
rope_opts.setdefault("shift_x", pe_len)
|
||||
transformer_options["rope_options"] = rope_opts
|
||||
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
@@ -1053,6 +1062,10 @@ class LTXAV(BaseModel):
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
if ref_audio is not None:
|
||||
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
@@ -1375,6 +1388,11 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "vace_context":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
@@ -1427,6 +1445,11 @@ class WAN21_HuMo(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_Animate(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||
@@ -1444,6 +1467,13 @@ class WAN22_Animate(WAN21):
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "face_pixel_values":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
|
||||
if cond_key == "pose_latents":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
@@ -1480,6 +1510,11 @@ class WAN22_S2V(WAN21):
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
@@ -1923,3 +1958,7 @@ class Kandinsky5Image(Kandinsky5):
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
@@ -696,6 +696,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["audio_model"] = "ace1.5"
|
||||
head_dim = 128
|
||||
dit_config["hidden_size"] = state_dict['{}decoder.layers.0.self_attn_norm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["intermediate_size"] = state_dict['{}decoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_heads"] = state_dict['{}decoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim
|
||||
|
||||
dit_config["encoder_hidden_size"] = state_dict['{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["encoder_num_heads"] = state_dict['{}encoder.lyric_encoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim
|
||||
dit_config["encoder_intermediate_size"] = state_dict['{}encoder.lyric_encoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_dit_layers"] = count_blocks(state_dict_keys, '{}decoder.layers.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "RT_DETR_v4"
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
|
||||
@@ -55,6 +55,7 @@ total_vram = 0
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
training_fp8_bwd = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
@@ -668,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if device is None or shift_model.device == device:
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
@@ -678,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
@@ -707,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
else:
|
||||
elif device is not None:
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
@@ -1325,9 +1326,9 @@ MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
@@ -1402,8 +1403,6 @@ def unpin_memory(tensor):
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
|
||||
@@ -300,9 +300,6 @@ class ModelPatcher:
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
|
||||
69
comfy/ops.py
69
comfy/ops.py
@@ -777,8 +777,16 @@ from .quant_ops import (
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||
|
||||
When training_fp8_bwd is enabled:
|
||||
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||
- Cached input is FP8 (half the memory of bf16)
|
||||
|
||||
When training_fp8_bwd is disabled:
|
||||
- Forward: quantize input per layout, use quantized matmul
|
||||
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
# Quantize input for forward (same layout as weight)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
@@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
# Unflatten output to match original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
# Save for backward
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||
|
||||
if ctx.fp8_bwd:
|
||||
# Cache FP8 quantized input — half the memory of bf16
|
||||
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||
ctx.q_input = q_input # already FP8, reuse
|
||||
else:
|
||||
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||
ctx.save_for_backward(weight)
|
||||
else:
|
||||
ctx.q_input = None
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
# Value casting — only difference between fp8 and non-fp8 paths
|
||||
if ctx.fp8_bwd:
|
||||
weight, = ctx.saved_tensors
|
||||
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||
weight_mm = weight
|
||||
elif isinstance(weight, QuantizedTensor):
|
||||
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
else:
|
||||
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
input_mm = ctx.q_input
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
input_float, weight = ctx.saved_tensors
|
||||
# Standard tensors → torch.mm does regular matmul
|
||||
grad_mm = grad_2d
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_mm = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_mm = weight.to(compute_dtype)
|
||||
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
# Computation — same for both paths, dispatch handles the rest
|
||||
grad_input = torch.mm(grad_mm, weight_mm)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
@@ -895,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
self.weight = None
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
@@ -1001,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
|
||||
@@ -2,6 +2,7 @@ import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import psutil
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@@ -12,6 +13,11 @@ def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
|
||||
#we split the difference and assume half the RAM cache headroom is for us
|
||||
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
|
||||
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
|
||||
@@ -8,12 +8,12 @@ import comfy.nested_tensor
|
||||
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
@@ -985,8 +985,8 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
@@ -1028,6 +1028,7 @@ class CFGGuider:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
denoise_mask = denoise_mask.float()
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
|
||||
93
comfy/sd.py
93
comfy/sd.py
@@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -279,9 +280,6 @@ class CLIP:
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.patcher.get_ram_usage()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@@ -425,13 +423,13 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@@ -455,7 +453,7 @@ class VAE:
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
@@ -558,12 +556,19 @@ class VAE:
|
||||
old_memory_used_decode = self.memory_used_decode
|
||||
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
|
||||
|
||||
decoder_ch = sd['decoder.conv_in.weight'].shape[0] // ddconfig['ch_mult'][-1]
|
||||
if decoder_ch != ddconfig['ch']:
|
||||
decoder_ddconfig = ddconfig.copy()
|
||||
decoder_ddconfig['ch'] = decoder_ch
|
||||
else:
|
||||
decoder_ddconfig = None
|
||||
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1], **({"decoder_ddconfig": decoder_ddconfig} if decoder_ddconfig is not None else {}))
|
||||
else:
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': decoder_ddconfig if decoder_ddconfig is not None else ddconfig})
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
config = {}
|
||||
param_key = None
|
||||
@@ -839,9 +844,6 @@ class VAE:
|
||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
@@ -951,12 +953,23 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
@@ -967,6 +980,7 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
@@ -1027,8 +1041,13 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
@@ -1043,6 +1062,7 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
@@ -1210,6 +1230,11 @@ class TEModel(Enum):
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
GEMMA_3_4B_VISION = 22
|
||||
QWEN35_08B = 23
|
||||
QWEN35_2B = 24
|
||||
QWEN35_4B = 25
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1249,6 +1274,17 @@ def detect_te_model(sd):
|
||||
return TEModel.QWEN25_3B
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
|
||||
weight = sd['model.language_model.layers.0.input_layernorm.weight']
|
||||
if weight.shape[0] == 1024:
|
||||
return TEModel.QWEN35_08B
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN35_4B
|
||||
if weight.shape[0] == 4096:
|
||||
return TEModel.QWEN35_9B
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@@ -1281,11 +1317,12 @@ def t5xxl_detect(clip_data):
|
||||
return {}
|
||||
|
||||
def llama_detect(clip_data):
|
||||
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
||||
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
|
||||
|
||||
for sd in clip_data:
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||
for weight_name in weight_names:
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||
|
||||
return {}
|
||||
|
||||
@@ -1413,6 +1450,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
@@ -1701,15 +1743,18 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
"""
|
||||
dtype = model_options.get("dtype", None)
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
|
||||
#Allow loading unets from checkpoint files
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||
if isinstance(tokens, dict):
|
||||
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||
else:
|
||||
tokens_only = tokens
|
||||
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
@@ -1734,6 +1734,21 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
class RT_DETR_v4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "RT_DETR_v4",
|
||||
}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.RT_DETR_v4(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
lm_head: bool = True
|
||||
stop_tokens = [151643, 151645]
|
||||
|
||||
@dataclass
|
||||
@@ -655,6 +655,17 @@ class Llama2_(nn.Module):
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
return past_key_values[0][2]
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
return precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
@@ -667,17 +678,12 @@ class Llama2_(nn.Module):
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
past_len = past_key_values[0][2]
|
||||
past_len = self.get_past_len(past_key_values)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
self.config.rope_dims,
|
||||
device=x.device)
|
||||
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
@@ -812,9 +818,16 @@ class BaseGenerate:
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
|
||||
device = embeds.device
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
stop_tokens = self.model.config.stop_tokens
|
||||
@@ -829,11 +842,8 @@ class BaseGenerate:
|
||||
if embeds.ndim == 2:
|
||||
embeds = embeds.unsqueeze(0)
|
||||
|
||||
past_key_values = [] #kv_cache init
|
||||
max_cache_len = embeds.shape[1] + max_length
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||
|
||||
@@ -844,7 +854,7 @@ class BaseGenerate:
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
@@ -856,7 +866,7 @@ class BaseGenerate:
|
||||
|
||||
return generated_token_ids
|
||||
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
|
||||
|
||||
if not do_sample or temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
@@ -867,6 +877,11 @@ class BaseGenerate:
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if presence_penalty is not None and presence_penalty != 0.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] -= presence_penalty
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
@@ -897,6 +912,9 @@ class BaseGenerate:
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
if self.model.config.lm_head:
|
||||
return self.model.lm_head(input)
|
||||
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
@@ -1028,12 +1046,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
if attention_mask is not None:
|
||||
# Assign compact sequential positions to attended tokens only,
|
||||
# skipping over padding so post-padding tokens aren't inflated.
|
||||
after_mask = attention_mask[0, end:]
|
||||
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
|
||||
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
|
||||
else:
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
|
||||
@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
|
||||
return [output]
|
||||
|
||||
|
||||
IMAGE_PAD_TOKEN_ID = 151655
|
||||
|
||||
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
name="qwen25_7b",
|
||||
tokenizer=LongCatImageBaseTokenizer,
|
||||
)
|
||||
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith("<|im_start|>"):
|
||||
skip_template = True
|
||||
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
|
||||
)
|
||||
else:
|
||||
has_images = images is not None and len(images) > 0
|
||||
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
|
||||
|
||||
prefix_ids = base_tok.tokenizer(
|
||||
self.longcat_template_prefix, add_special_tokens=False
|
||||
template_prefix, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
suffix_ids = base_tok.tokenizer(
|
||||
self.longcat_template_suffix, add_special_tokens=False
|
||||
self.SUFFIX, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
|
||||
prompt_tokens = base_tok.tokenize_with_weights(
|
||||
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
suffix_pairs = [(t, 1.0) for t in suffix_ids]
|
||||
|
||||
combined = prefix_pairs + prompt_pairs + suffix_pairs
|
||||
|
||||
if has_images:
|
||||
embed_count = 0
|
||||
for i in range(len(combined)):
|
||||
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
|
||||
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
|
||||
embed_count += 1
|
||||
|
||||
tokens = {"qwen25_7b": [combined]}
|
||||
|
||||
return tokens
|
||||
|
||||
@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
self.dtypes.add(dtype)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
|
||||
833
comfy/text_encoders/qwen35.py
Normal file
833
comfy/text_encoders/qwen35.py
Normal file
@@ -0,0 +1,833 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import math
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.qwen_vl
|
||||
|
||||
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
||||
|
||||
|
||||
def _qwen35_layer_types(n):
|
||||
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
||||
|
||||
@dataclass
|
||||
class Qwen35Config:
|
||||
vocab_size: int = 248320
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 24
|
||||
# Full attention params
|
||||
num_attention_heads: int = 8
|
||||
num_key_value_heads: int = 2
|
||||
head_dim: int = 256
|
||||
partial_rotary_factor: float = 0.25
|
||||
# Linear attention (DeltaNet) params
|
||||
linear_num_key_heads: int = 16
|
||||
linear_num_value_heads: int = 16
|
||||
linear_key_head_dim: int = 128
|
||||
linear_value_head_dim: int = 128
|
||||
conv_kernel_size: int = 4
|
||||
# Shared params
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000000.0
|
||||
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
|
||||
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
|
||||
rms_norm_add: bool = True
|
||||
mlp_activation: str = "silu"
|
||||
qkv_bias: bool = False
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
|
||||
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
|
||||
transformer_type: str = "qwen35_2b"
|
||||
rope_dims: list = None
|
||||
rope_scale: float = None
|
||||
|
||||
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
|
||||
|
||||
QWEN35_MODELS = {
|
||||
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
|
||||
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
|
||||
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
|
||||
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||
}
|
||||
|
||||
|
||||
def _make_config(model_type, config_dict={}):
|
||||
overrides = QWEN35_MODELS.get(model_type, {}).copy()
|
||||
overrides.pop("vision", None)
|
||||
if "num_hidden_layers" in overrides:
|
||||
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
|
||||
overrides.update(config_dict)
|
||||
return Qwen35Config(**overrides)
|
||||
|
||||
|
||||
class RMSNormGated(RMSNorm):
|
||||
def forward(self, x, gate):
|
||||
return super().forward(x) * F.silu(gate.to(x.dtype))
|
||||
|
||||
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
|
||||
initial_dtype = query.dtype
|
||||
query = F.normalize(query, dim=-1)
|
||||
key = F.normalize(key, dim=-1)
|
||||
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size))
|
||||
key = F.pad(key, (0, 0, 0, pad_size))
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
total_sequence_length = sequence_length + pad_size
|
||||
scale = 1 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
last_recurrent_state = (
|
||||
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||
|
||||
for i in range(0, total_sequence_length // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
|
||||
|
||||
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
|
||||
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
|
||||
# weight: [channels, kernel_size]
|
||||
state_len = conv_state.shape[-1]
|
||||
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
|
||||
conv_state.copy_(combined[:, :, -state_len:])
|
||||
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
|
||||
if bias is not None:
|
||||
out = out + bias.unsqueeze(0).unsqueeze(-1)
|
||||
return F.silu(out).to(x.dtype)
|
||||
|
||||
|
||||
# GatedDeltaNet - Linear Attention Layer
|
||||
|
||||
class GatedDeltaNet(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
hidden = config.hidden_size
|
||||
self.num_key_heads = config.linear_num_key_heads
|
||||
self.num_value_heads = config.linear_num_value_heads
|
||||
self.key_head_dim = config.linear_key_head_dim
|
||||
self.value_head_dim = config.linear_value_head_dim
|
||||
self.conv_kernel_size = config.conv_kernel_size
|
||||
|
||||
key_dim = self.num_key_heads * self.key_head_dim
|
||||
value_dim = self.num_value_heads * self.value_head_dim
|
||||
self.key_dim = key_dim
|
||||
self.value_dim = value_dim
|
||||
conv_dim = key_dim * 2 + value_dim
|
||||
|
||||
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||
|
||||
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
|
||||
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
|
||||
|
||||
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, past_key_value=None, **kwargs):
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
use_recurrent = (
|
||||
past_key_value is not None
|
||||
and past_key_value[2] > 0
|
||||
and seq_len == 1
|
||||
)
|
||||
|
||||
# Projections (shared)
|
||||
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
|
||||
z = self.in_proj_z(x)
|
||||
b = self.in_proj_b(x)
|
||||
a = self.in_proj_a(x)
|
||||
|
||||
# Conv1d
|
||||
if use_recurrent:
|
||||
recurrent_state, conv_state, step_index = past_key_value
|
||||
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
|
||||
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
|
||||
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
|
||||
else:
|
||||
if past_key_value is not None:
|
||||
recurrent_state, conv_state, step_index = past_key_value
|
||||
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
||||
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
|
||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||
|
||||
# Split QKV and compute beta/g
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
|
||||
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
|
||||
beta = b.sigmoid()
|
||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
|
||||
|
||||
# Delta rule
|
||||
if use_recurrent:
|
||||
# single-token path: work in [B, heads, dim] without seq dim
|
||||
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
|
||||
|
||||
if self.num_value_heads != self.num_key_heads:
|
||||
rep = self.num_value_heads // self.num_key_heads
|
||||
query = query.repeat_interleave(rep, dim=1)
|
||||
key = key.repeat_interleave(rep, dim=1)
|
||||
|
||||
scale = self.key_head_dim ** -0.5
|
||||
q = F.normalize(query.float(), dim=-1) * scale
|
||||
k = F.normalize(key.float(), dim=-1)
|
||||
v = value.float()
|
||||
beta_t = beta.reshape(batch_size, -1)
|
||||
g_t = g.reshape(batch_size, -1).exp()
|
||||
|
||||
# In-place state update: [B, heads, k_dim, v_dim]
|
||||
recurrent_state.mul_(g_t[:, :, None, None])
|
||||
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
|
||||
delta = (v - kv_mem) * beta_t[:, :, None]
|
||||
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
|
||||
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
|
||||
|
||||
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
|
||||
present_key_value = (recurrent_state, conv_state, step_index + 1)
|
||||
else:
|
||||
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
|
||||
|
||||
if self.num_value_heads != self.num_key_heads:
|
||||
rep = self.num_value_heads // self.num_key_heads
|
||||
query = query.repeat_interleave(rep, dim=2)
|
||||
key = key.repeat_interleave(rep, dim=2)
|
||||
|
||||
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
|
||||
query, key, value, g=g, beta=beta,
|
||||
initial_state=None,
|
||||
output_final_state=past_key_value is not None,
|
||||
)
|
||||
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
if last_recurrent_state is not None:
|
||||
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
|
||||
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
|
||||
|
||||
# Gated norm + output projection (shared)
|
||||
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
|
||||
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
|
||||
return output, present_key_value
|
||||
|
||||
|
||||
# GatedAttention - Full Attention with output gating
|
||||
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
|
||||
"""Compute RoPE frequencies for partial rotary embeddings."""
|
||||
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
|
||||
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
|
||||
|
||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
if mrope_section is not None and position_ids.shape[0] == 3:
|
||||
mrope_section_2 = [s * 2 for s in mrope_section]
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
|
||||
|
||||
|
||||
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
|
||||
"""Apply RoPE to only the first rotary_dim dimensions."""
|
||||
xq_rot = xq[..., :rotary_dim]
|
||||
xq_pass = xq[..., rotary_dim:]
|
||||
xk_rot = xk[..., :rotary_dim]
|
||||
xk_pass = xk[..., rotary_dim:]
|
||||
|
||||
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
|
||||
|
||||
xq = torch.cat([xq_rot, xq_pass], dim=-1)
|
||||
xk = torch.cat([xk_rot, xk_pass], dim=-1)
|
||||
return xq, xk
|
||||
|
||||
|
||||
class GatedAttention(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
|
||||
|
||||
# q_proj outputs 2x: query + gate
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
# QK norms with (1+weight) scaling
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||
batch_size, seq_length, _ = x.shape
|
||||
|
||||
# Project Q (with gate), K, V
|
||||
qg = self.q_proj(x)
|
||||
# Split into query and gate: each is [B, seq, inner_size]
|
||||
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
|
||||
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
|
||||
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
|
||||
|
||||
xk = self.k_proj(x)
|
||||
xv = self.v_proj(x)
|
||||
|
||||
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
|
||||
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
|
||||
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply partial RoPE
|
||||
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
|
||||
|
||||
# KV cache
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
past_key, past_value, index = past_key_value
|
||||
num_tokens = xk.shape[2]
|
||||
if past_key.shape[2] >= (index + num_tokens):
|
||||
past_key[:, :, index:index + num_tokens] = xk
|
||||
past_value[:, :, index:index + num_tokens] = xv
|
||||
xk = past_key[:, :, :index + num_tokens]
|
||||
xv = past_value[:, :, :index + num_tokens]
|
||||
present_key_value = (past_key, past_value, index + num_tokens)
|
||||
else:
|
||||
if index > 0:
|
||||
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_heads != self.num_kv_heads:
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
output = output * gate.sigmoid()
|
||||
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
|
||||
# Hybrid Transformer Block
|
||||
class Qwen35TransformerBlock(nn.Module):
|
||||
def __init__(self, config, index, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[index]
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
|
||||
else:
|
||||
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||
if self.layer_type == "linear_attention":
|
||||
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
|
||||
else:
|
||||
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
|
||||
|
||||
x = x + h
|
||||
x = x + self.mlp(self.post_attention_layernorm(x))
|
||||
return x, present_key_value
|
||||
|
||||
|
||||
# Qwen35 Transformer Backbone
|
||||
class Qwen35Transformer(Llama2_):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.normalize_in = False
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
for i, layer in enumerate(self.layers):
|
||||
if layer.layer_type == "full_attention":
|
||||
if len(past_key_values) > i:
|
||||
return past_key_values[i][2]
|
||||
break
|
||||
return 0
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
|
||||
return precompute_partial_rope(
|
||||
self.config.head_dim, rotary_dim, position_ids,
|
||||
self.config.rope_theta, device=device,
|
||||
mrope_section=self.config.mrope_section,
|
||||
)
|
||||
|
||||
|
||||
# Vision Encoder
|
||||
class Qwen35VisionPatchEmbed(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.patch_size = config["patch_size"]
|
||||
self.temporal_patch_size = config["temporal_patch_size"]
|
||||
self.in_channels = config["in_channels"]
|
||||
self.embed_dim = config["hidden_size"]
|
||||
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
|
||||
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
target_dtype = self.proj.weight.dtype
|
||||
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
|
||||
|
||||
|
||||
class Qwen35VisionMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
|
||||
|
||||
|
||||
class Qwen35VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, theta=10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, seqlen):
|
||||
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(seq, self.inv_freq)
|
||||
return freqs
|
||||
|
||||
|
||||
class Qwen35VisionAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||
seq_length = x.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
|
||||
|
||||
# Process per-sequence attention
|
||||
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_splits = torch.split(query_states, lengths, dim=0)
|
||||
k_splits = torch.split(key_states, lengths, dim=0)
|
||||
v_splits = torch.split(value_states, lengths, dim=0)
|
||||
|
||||
attn_outputs = []
|
||||
for q, k, v in zip(q_splits, k_splits, v_splits):
|
||||
q = q.transpose(0, 1).unsqueeze(0)
|
||||
k = k.transpose(0, 1).unsqueeze(0)
|
||||
v = v.transpose(0, 1).unsqueeze(0)
|
||||
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
|
||||
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
return self.proj(attn_output)
|
||||
|
||||
|
||||
class Qwen35VisionBlock(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Qwen35VisionPatchMerger(nn.Module):
|
||||
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
merge_dim = hidden_size * (spatial_merge_size ** 2)
|
||||
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
|
||||
self.merge_dim = merge_dim
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x).view(-1, self.merge_dim)
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
||||
|
||||
|
||||
class Qwen35VisionModel(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = config["spatial_merge_size"]
|
||||
self.patch_size = config["patch_size"]
|
||||
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
self.hidden_size = config["hidden_size"]
|
||||
self.num_heads = config["num_heads"]
|
||||
self.num_position_embeddings = config["num_position_embeddings"]
|
||||
|
||||
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
|
||||
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
|
||||
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
|
||||
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config["depth"])
|
||||
])
|
||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def rot_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||
freq_table = self.rotary_pos_emb(max_hw)
|
||||
device = freq_table.device
|
||||
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||
offset = 0
|
||||
for num_frames, height, width in grid_thw_list:
|
||||
num_frames, height, width = int(num_frames), int(height), int(width)
|
||||
merged_h, merged_w = height // merge_size, width // merge_size
|
||||
block_rows = torch.arange(merged_h, device=device)
|
||||
block_cols = torch.arange(merged_w, device=device)
|
||||
intra_row = torch.arange(merge_size, device=device)
|
||||
intra_col = torch.arange(merge_size, device=device)
|
||||
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
||||
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
||||
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||
if num_frames > 1:
|
||||
coords = coords.repeat(num_frames, 1)
|
||||
num_tokens = coords.shape[0]
|
||||
pos_ids[offset:offset + num_tokens] = coords
|
||||
offset += num_tokens
|
||||
embeddings = freq_table[pos_ids]
|
||||
embeddings = embeddings.flatten(1)
|
||||
return embeddings
|
||||
|
||||
def fast_pos_embed_interpolate(self, grid_thw):
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_ts = [int(row[0]) for row in grid_thw_list]
|
||||
grid_hs = [int(row[1]) for row in grid_thw_list]
|
||||
grid_ws = [int(row[2]) for row in grid_thw_list]
|
||||
device = self.pos_embed.weight.device
|
||||
idx_list = [[] for _ in range(4)]
|
||||
weight_list = [[] for _ in range(4)]
|
||||
for t, h, w in grid_thw_list:
|
||||
h, w = int(h), int(w)
|
||||
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
|
||||
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
|
||||
h_idxs_floor = h_idxs.int()
|
||||
w_idxs_floor = w_idxs.int()
|
||||
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
dh = h_idxs - h_idxs_floor
|
||||
dw = w_idxs - w_idxs_floor
|
||||
base_h = h_idxs_floor * self.num_grid_per_side
|
||||
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
|
||||
indices = [
|
||||
(base_h[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h[None].T + w_idxs_ceil[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
|
||||
]
|
||||
weights = [
|
||||
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
|
||||
((1 - dh)[None].T * dw[None]).flatten(),
|
||||
(dh[None].T * (1 - dw)[None]).flatten(),
|
||||
(dh[None].T * dw[None]).flatten(),
|
||||
]
|
||||
for j in range(4):
|
||||
idx_list[j].extend(indices[j].tolist())
|
||||
weight_list[j].extend(weights[j].tolist())
|
||||
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
|
||||
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
|
||||
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
||||
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
|
||||
patch_pos_embeds_permute = []
|
||||
merge_size = self.spatial_merge_size
|
||||
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
|
||||
pos_embed = pos_embed.repeat(t, 1)
|
||||
pos_embed = (
|
||||
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
|
||||
.permute(0, 1, 3, 2, 4, 5)
|
||||
.flatten(0, 4)
|
||||
)
|
||||
patch_pos_embeds_permute.append(pos_embed)
|
||||
return torch.cat(patch_pos_embeds_permute)
|
||||
|
||||
def forward(self, x, grid_thw):
|
||||
x = self.patch_embed(x)
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
||||
x = x + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
seq_len = x.shape[0]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos().unsqueeze(-2)
|
||||
sin = emb.sin().unsqueeze(-2)
|
||||
sin_half = sin.shape[-1] // 2
|
||||
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
merged = self.merger(x)
|
||||
return merged
|
||||
|
||||
# Model Wrapper
|
||||
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
model_type = "qwen35_2b"
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = _make_config(self.model_type, config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
|
||||
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
|
||||
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
|
||||
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for i in range(model_config.num_hidden_layers):
|
||||
if model_config.layer_types[i] == "linear_attention":
|
||||
recurrent_state = torch.zeros(
|
||||
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
|
||||
device=device, dtype=torch.float32
|
||||
)
|
||||
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
|
||||
conv_state = torch.zeros(
|
||||
[batch, conv_dim, model_config.conv_kernel_size - 1],
|
||||
device=device, dtype=execution_dtype
|
||||
)
|
||||
past_key_values.append((recurrent_state, conv_state, 0))
|
||||
else:
|
||||
past_key_values.append((
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
0
|
||||
))
|
||||
return past_key_values
|
||||
|
||||
# Tokenizer and Text Encoder Wrappers
|
||||
|
||||
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
|
||||
from transformers import Qwen2Tokenizer
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
|
||||
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
|
||||
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image]
|
||||
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
if not thinking:
|
||||
llama_text += "<think>\n</think>\n"
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 248056: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen35ClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
|
||||
class Qwen35_(Qwen35):
|
||||
pass
|
||||
Qwen35_.model_type = model_type
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
|
||||
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Qwen35TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
|
||||
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def tokenizer(model_type="qwen35_2b"):
|
||||
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||
return Qwen35ImageTokenizer_
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
|
||||
class Qwen35TEModel_(Qwen35TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||
return Qwen35TEModel_
|
||||
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
File diff suppressed because one or more lines are too long
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
|
||||
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
# Potentially important for spatially precise edits. This is present in the HF implementation.
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
hidden_states = hidden_states[reverse_indices, :]
|
||||
return hidden_states
|
||||
|
||||
@@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out = output[b:b+1].zero_()
|
||||
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
@@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
@@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
output[b:b+1] = out/out_div
|
||||
out.div_(out_div)
|
||||
return output
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
|
||||
@@ -5,6 +5,11 @@ from comfy_api.latest._input import (
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
CurvePoint,
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
RangeInput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -13,4 +18,9 @@ __all__ = [
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
"RangeInput",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .range_types import RangeInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@@ -7,4 +9,9 @@ __all__ = [
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
"RangeInput",
|
||||
]
|
||||
|
||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> CurveInput:
|
||||
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||
|
||||
Accepts:
|
||||
- A ``CurveInput`` instance (returned as-is).
|
||||
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||
"""
|
||||
if isinstance(data, CurveInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
raw_points = data["points"]
|
||||
interpolation = data.get("interpolation", "monotone_cubic")
|
||||
else:
|
||||
raw_points = data
|
||||
interpolation = "monotone_cubic"
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
if interpolation == "linear":
|
||||
return LinearCurve(points)
|
||||
if interpolation != "monotone_cubic":
|
||||
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||
return MonotoneCubicCurve(points)
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
70
comfy_api/latest/_input/range_types.py
Normal file
70
comfy_api/latest/_input/range_types.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RangeInput:
|
||||
"""Represents a levels/range adjustment: input range [min, max] with
|
||||
optional midpoint (gamma control).
|
||||
|
||||
Generates a 1D LUT identical to GIMP's levels mapping:
|
||||
1. Normalize input to [0, 1] using [min, max]
|
||||
2. Apply gamma correction: pow(value, 1/gamma)
|
||||
3. Clamp to [0, 1]
|
||||
|
||||
The midpoint field is a position in [0, 1] representing where the
|
||||
midtone falls within [min, max]. It maps to gamma via:
|
||||
gamma = -log2(midpoint)
|
||||
So midpoint=0.5 → gamma=1.0 (linear).
|
||||
"""
|
||||
|
||||
def __init__(self, min_val: float, max_val: float, midpoint: float | None = None):
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
self.midpoint = midpoint
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> RangeInput:
|
||||
if isinstance(data, RangeInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
return RangeInput(
|
||||
min_val=float(data.get("min", 0.0)),
|
||||
max_val=float(data.get("max", 1.0)),
|
||||
midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None,
|
||||
)
|
||||
raise TypeError(f"Cannot convert {type(data)} to RangeInput")
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table mapping [0, 1] input through this
|
||||
levels adjustment.
|
||||
|
||||
The LUT maps normalized input values (0..1) to output values (0..1),
|
||||
matching the GIMP levels formula.
|
||||
"""
|
||||
xs = np.linspace(0.0, 1.0, size, dtype=np.float64)
|
||||
|
||||
in_range = self.max_val - self.min_val
|
||||
if abs(in_range) < 1e-10:
|
||||
return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64)
|
||||
|
||||
# Normalize: map [min, max] → [0, 1]
|
||||
result = (xs - self.min_val) / in_range
|
||||
result = np.clip(result, 0.0, 1.0)
|
||||
|
||||
# Gamma correction from midpoint
|
||||
if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5:
|
||||
gamma = max(-math.log2(self.midpoint), 0.001)
|
||||
inv_gamma = 1.0 / gamma
|
||||
mask = result > 0
|
||||
result[mask] = np.power(result[mask], inv_gamma)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
mid = f", midpoint={self.midpoint}" if self.midpoint is not None else ""
|
||||
return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})"
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
from comfy_api.input import CurvePoint
|
||||
if TYPE_CHECKING:
|
||||
Type = CurveInput_
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
@@ -1252,6 +1253,55 @@ class Curve(ComfyTypeIO):
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.default is not None:
|
||||
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
Type = list[int]
|
||||
|
||||
|
||||
@comfytype(io_type="RANGE")
|
||||
class Range(ComfyTypeIO):
|
||||
from comfy_api.input import RangeInput
|
||||
if TYPE_CHECKING:
|
||||
Type = RangeInput
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None,
|
||||
display: str=None,
|
||||
gradient_stops: list=None,
|
||||
show_midpoint: bool=None,
|
||||
midpoint_scale: str=None,
|
||||
value_min: float=None,
|
||||
value_max: float=None,
|
||||
advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
if default is None:
|
||||
self.default = {"min": 0.0, "max": 1.0}
|
||||
self.display = display
|
||||
self.gradient_stops = gradient_stops
|
||||
self.show_midpoint = show_midpoint
|
||||
self.midpoint_scale = midpoint_scale
|
||||
self.value_min = value_min
|
||||
self.value_max = value_max
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"display": self.display,
|
||||
"gradient_stops": self.gradient_stops,
|
||||
"show_midpoint": self.show_midpoint,
|
||||
"midpoint_scale": self.midpoint_scale,
|
||||
"value_min": self.value_min,
|
||||
"value_max": self.value_max,
|
||||
})
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
@@ -1360,6 +1410,7 @@ class NodeInfoV1:
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
essentials_category: str=None
|
||||
has_intermediate_output: bool=None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1483,6 +1534,16 @@ class Schema:
|
||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||
essentials_category: str | None = None
|
||||
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
|
||||
has_intermediate_output: bool=False
|
||||
"""Flags this node as having intermediate output that should persist across page refreshes.
|
||||
|
||||
Nodes with this flag behave like output nodes (their UI results are cached and resent
|
||||
to the frontend) but do NOT automatically get added to the execution list. This means
|
||||
they will only execute if they are on the dependency path of a real output node.
|
||||
|
||||
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
|
||||
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
|
||||
"""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@@ -1582,6 +1643,7 @@ class Schema:
|
||||
category=self.category,
|
||||
description=self.description,
|
||||
output_node=self.is_output_node,
|
||||
has_intermediate_output=self.has_intermediate_output,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
dev_only=self.is_dev_only,
|
||||
@@ -1873,6 +1935,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._OUTPUT_NODE
|
||||
|
||||
_HAS_INTERMEDIATE_OUTPUT = None
|
||||
@final
|
||||
@classproperty
|
||||
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
|
||||
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._HAS_INTERMEDIATE_OUTPUT
|
||||
|
||||
_INPUT_IS_LIST = None
|
||||
@final
|
||||
@classproperty
|
||||
@@ -1965,6 +2035,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._API_NODE = schema.is_api_node
|
||||
if cls._OUTPUT_NODE is None:
|
||||
cls._OUTPUT_NODE = schema.is_output_node
|
||||
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
|
||||
if cls._INPUT_IS_LIST is None:
|
||||
cls._INPUT_IS_LIST = schema.is_input_list
|
||||
if cls._NOT_IDEMPOTENT is None:
|
||||
@@ -2240,5 +2312,7 @@ __all__ = [
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"Range",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
thought: bool | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
|
||||
@@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
image: InputUrlObject | None = Field(...)
|
||||
image: InputUrlObject | None = Field(None)
|
||||
reference_images: list[InputUrlObject] | None = Field(None)
|
||||
duration: int = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class VideoExtensionRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
video: InputUrlObject = Field(...)
|
||||
duration: int = Field(default=6)
|
||||
model: str | None = Field(default=None)
|
||||
|
||||
|
||||
class VideoEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
|
||||
43
comfy_api_nodes/apis/quiver.py
Normal file
43
comfy_api_nodes/apis/quiver.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class QuiverImageObject(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class QuiverTextToSVGRequest(BaseModel):
|
||||
model: str = Field(default="arrow-preview")
|
||||
prompt: str = Field(...)
|
||||
instructions: str | None = Field(default=None)
|
||||
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
|
||||
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
class QuiverImageToSVGRequest(BaseModel):
|
||||
model: str = Field(default="arrow-preview")
|
||||
image: QuiverImageObject = Field(...)
|
||||
auto_crop: bool | None = Field(default=None)
|
||||
target_size: int | None = Field(default=None, ge=128, le=4096)
|
||||
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
class QuiverSVGResponseItem(BaseModel):
|
||||
svg: str = Field(...)
|
||||
mime_type: str | None = Field(default="image/svg+xml")
|
||||
|
||||
|
||||
class QuiverSVGUsage(BaseModel):
|
||||
total_tokens: int | None = Field(default=None)
|
||||
input_tokens: int | None = Field(default=None)
|
||||
output_tokens: int | None = Field(default=None)
|
||||
|
||||
|
||||
class QuiverSVGResponse(BaseModel):
|
||||
id: str | None = Field(default=None)
|
||||
created: int | None = Field(default=None)
|
||||
data: list[QuiverSVGResponseItem] = Field(...)
|
||||
usage: QuiverSVGUsage | None = Field(default=None)
|
||||
226
comfy_api_nodes/apis/wan.py
Normal file
226
comfy_api_nodes/apis/wan.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
|
||||
|
||||
class Image2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||
|
||||
|
||||
class Text2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Image2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
img_url: str = Field(...)
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Reference2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
reference_video_urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class Txt2ImageParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Image2ImageParametersField(BaseModel):
|
||||
size: str | None = Field(None)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Image2VideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Reference2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
shot_type: str = Field("single")
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2ImageInputField = Field(...)
|
||||
parameters: Txt2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Image2ImageInputField = Field(...)
|
||||
parameters: Image2ImageParametersField = Field(...)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2VideoInputField = Field(...)
|
||||
parameters: Text2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Image2VideoInputField = Field(...)
|
||||
parameters: Image2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Reference2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Reference2VideoInputField = Field(...)
|
||||
parameters: Reference2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27MediaItem(BaseModel):
|
||||
type: str = Field(...)
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class Wan27ReferenceVideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
media: list[Wan27MediaItem] = Field(...)
|
||||
|
||||
|
||||
class Wan27ReferenceVideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int = Field(5, ge=2, le=10)
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27ReferenceVideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Wan27ReferenceVideoInputField = Field(...)
|
||||
parameters: Wan27ReferenceVideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27ImageToVideoInputField(BaseModel):
|
||||
prompt: str | None = Field(None)
|
||||
negative_prompt: str | None = Field(None)
|
||||
media: list[Wan27MediaItem] = Field(...)
|
||||
|
||||
|
||||
class Wan27ImageToVideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
duration: int = Field(5, ge=2, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27ImageToVideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Wan27ImageToVideoInputField = Field(...)
|
||||
parameters: Wan27ImageToVideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27VideoEditInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
media: list[Wan27MediaItem] = Field(...)
|
||||
|
||||
|
||||
class Wan27VideoEditParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int = Field(0)
|
||||
audio_setting: str = Field("auto")
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27VideoEditTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Wan27VideoEditInputField = Field(...)
|
||||
parameters: Wan27VideoEditParametersField = Field(...)
|
||||
|
||||
|
||||
class Wan27Text2VideoParametersField(BaseModel):
|
||||
resolution: str = Field(...)
|
||||
ratio: str | None = Field(None)
|
||||
duration: int = Field(5, ge=2, le=15)
|
||||
prompt_extend: bool = Field(True)
|
||||
watermark: bool = Field(False)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
|
||||
|
||||
class Wan27Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2VideoInputField = Field(...)
|
||||
parameters: Wan27Text2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class TaskCreationOutputField(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
output: TaskCreationOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
code: str | None = Field(None, description="Error code for the failed request.")
|
||||
message: str | None = Field(None, description="Details about the failed request.")
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
url: str | None = Field(None)
|
||||
code: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
results: list[TaskResult] | None = Field(None)
|
||||
|
||||
|
||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
video_url: str | None = Field(None)
|
||||
code: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageTaskStatusResponse(BaseModel):
|
||||
output: ImageTaskStatusOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
class VideoTaskStatusResponse(BaseModel):
|
||||
output: VideoTaskStatusOutputField | None = Field(None)
|
||||
request_id: str = Field(...)
|
||||
@@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
@@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
]
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@@ -952,6 +957,12 @@ async def process_video_task(
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
|
||||
@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$isFlash := $contains($m, "nano banana 2");
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
@@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if (part.thought is True) != thought:
|
||||
continue
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
@@ -199,6 +201,16 @@ async def get_image_from_response(response: GeminiGenerateContentResponse) -> In
|
||||
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
if not thought:
|
||||
# No images generated --> extract text response for a meaningful error
|
||||
model_message = get_text_from_response(response).strip()
|
||||
if model_message:
|
||||
raise ValueError(f"Gemini did not generate an image. Model response: {model_message}")
|
||||
raise ValueError(
|
||||
"Gemini did not generate an image. "
|
||||
"Try rephrasing your prompt or changing the response modality to 'IMAGE+TEXT' "
|
||||
"to see the model's reasoning."
|
||||
)
|
||||
return torch.zeros((1, 1024, 1024, 4))
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
@@ -931,6 +943,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -992,7 +1009,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
||||
ImageGenerationResponse,
|
||||
InputUrlObject,
|
||||
VideoEditRequest,
|
||||
VideoExtensionRequest,
|
||||
VideoGenerationRequest,
|
||||
VideoGenerationResponse,
|
||||
VideoStatusResponse,
|
||||
@@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
@@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_grok_video_price(response) -> float | None:
|
||||
price = _extract_grok_price(response)
|
||||
if price is not None:
|
||||
return price * 1.43
|
||||
return None
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "grok-imagine-video-beta":
|
||||
model = "grok-imagine-video"
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="reference_",
|
||||
min=1,
|
||||
max=7,
|
||||
),
|
||||
tooltip="Up to 7 reference images to guide the video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=6,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="The duration of the output video in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model.duration", "model.resolution"],
|
||||
input_groups=["model.reference_images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
ref_image_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
list(model["reference_images"].values()),
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading base images",
|
||||
max_images=7,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model["model"],
|
||||
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||
prompt=prompt,
|
||||
resolution=model["resolution"],
|
||||
duration=model["duration"],
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoExtendNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of what should happen next in the video.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Length of the extension in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video extension.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
video_size = get_fs_object_size(video.get_stream_source())
|
||||
if video_size > 50 * 1024 * 1024:
|
||||
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||
data=VideoExtensionRequest(
|
||||
prompt=prompt,
|
||||
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||
duration=model["duration"],
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokVideoNode,
|
||||
GrokVideoReferenceNode,
|
||||
GrokVideoEditNode,
|
||||
GrokVideoExtendNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
tooltip="The LowPoly option is unavailable for the `3.1` model.",
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
|
||||
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
|
||||
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
|
||||
IO.DynamicCombo.Input(
|
||||
"generate_type",
|
||||
options=[
|
||||
@@ -251,7 +251,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
IO.Image.Input("image_left", optional=True),
|
||||
IO.Image.Input("image_right", optional=True),
|
||||
IO.Image.Input("image_back", optional=True),
|
||||
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
|
||||
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
|
||||
IO.DynamicCombo.Input(
|
||||
"generate_type",
|
||||
options=[
|
||||
@@ -422,6 +422,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
IO.Image.Output(display_name="uv_image"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -468,9 +469,16 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
uv_image_file = get_file_from_response(result.ResultFile3Ds, "uv_image", raise_if_not_found=False)
|
||||
uv_image = (
|
||||
await download_url_to_image_tensor(uv_image_file.Url)
|
||||
if uv_image_file is not None
|
||||
else torch.zeros(1, 1, 1, 3)
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
uv_image,
|
||||
)
|
||||
|
||||
|
||||
|
||||
291
comfy_api_nodes/nodes_quiver.py
Normal file
291
comfy_api_nodes/nodes_quiver.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.quiver import (
|
||||
QuiverImageObject,
|
||||
QuiverImageToSVGRequest,
|
||||
QuiverSVGResponse,
|
||||
QuiverTextToSVGRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_extras.nodes_images import SVG
|
||||
|
||||
|
||||
class QuiverTextToSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="QuiverTextToSVGNode",
|
||||
display_name="Quiver Text to SVG",
|
||||
category="api node/image/Quiver",
|
||||
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired SVG output.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"instructions",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Additional style or formatting guidance.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="ref_",
|
||||
min=0,
|
||||
max=4,
|
||||
),
|
||||
tooltip="Up to 4 reference images to guide the generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"arrow-preview",
|
||||
[
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Randomness control. Higher values increase randomness.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=1.0,
|
||||
min=0.05,
|
||||
max=1.0,
|
||||
step=0.05,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Nucleus sampling parameter.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"presence_penalty",
|
||||
default=0.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Token presence penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for SVG generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.429}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
instructions: str = None,
|
||||
reference_images: IO.Autogrow.Type = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
|
||||
references = None
|
||||
if reference_images:
|
||||
references = []
|
||||
for key in reference_images:
|
||||
url = await upload_image_to_comfyapi(cls, reference_images[key])
|
||||
references.append(QuiverImageObject(url=url))
|
||||
if len(references) > 4:
|
||||
raise ValueError("Maximum 4 reference images are allowed.")
|
||||
|
||||
instructions_val = instructions.strip() if instructions else None
|
||||
if instructions_val == "":
|
||||
instructions_val = None
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
|
||||
response_model=QuiverSVGResponse,
|
||||
data=QuiverTextToSVGRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
instructions=instructions_val,
|
||||
references=references,
|
||||
temperature=model.get("temperature"),
|
||||
top_p=model.get("top_p"),
|
||||
presence_penalty=model.get("presence_penalty"),
|
||||
),
|
||||
)
|
||||
|
||||
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class QuiverImageToSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="QuiverImageToSVGNode",
|
||||
display_name="Quiver Image to SVG",
|
||||
category="api node/image/Quiver",
|
||||
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Input image to vectorize.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_crop",
|
||||
default=False,
|
||||
tooltip="Automatically crop to the dominant subject.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"arrow-preview",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"target_size",
|
||||
default=1024,
|
||||
min=128,
|
||||
max=4096,
|
||||
tooltip="Square resize target in pixels.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Randomness control. Higher values increase randomness.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=1.0,
|
||||
min=0.05,
|
||||
max=1.0,
|
||||
step=0.05,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Nucleus sampling parameter.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"presence_penalty",
|
||||
default=0.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Token presence penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for SVG vectorization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.429}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image,
|
||||
auto_crop: bool,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_url = await upload_image_to_comfyapi(cls, image)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
|
||||
response_model=QuiverSVGResponse,
|
||||
data=QuiverImageToSVGRequest(
|
||||
model=model["model"],
|
||||
image=QuiverImageObject(url=image_url),
|
||||
auto_crop=auto_crop if auto_crop else None,
|
||||
target_size=model.get("target_size"),
|
||||
temperature=model.get("temperature"),
|
||||
top_p=model.get("top_p"),
|
||||
presence_penalty=model.get("presence_penalty"),
|
||||
),
|
||||
)
|
||||
|
||||
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class QuiverExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
QuiverTextToSVGNode,
|
||||
QuiverImageToSVGNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> QuiverExtension:
|
||||
return QuiverExtension()
|
||||
@@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
widgets.upscale = "enabled" ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0457, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
$enabled := widgets.upscale = "enabled";
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$isFast
|
||||
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||
: $enabled ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
$enabled := widgets.upscale = "enabled";
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$isFast
|
||||
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||
: $enabled ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@@ -38,6 +38,7 @@ from comfy_api_nodes.util import (
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
"Starlight Precise 2.5": "slp-2.5",
|
||||
}
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user