Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-03-29 17:07:28 +00:00

Compare commits: feat/api-n...master
27 Commits
| Author | SHA1 | Date |
|---|---|---|
| | a500f1edac | |
| | 3f77450ef1 | |
| | fc1fdf3389 | |
| | b353a7c863 | |
| | 3696c5bad6 | |
| | 3a56201da5 | |
| | 6a2cdb817d | |
| | 85b7495135 | |
| | 225c52f6a4 | |
| | b1fdbeb9a7 | |
| | 1dc64f3526 | |
| | 359559c913 | |
| | 8165485a17 | |
| | b0fd65e884 | |
| | 2a1f402601 | |
| | 3eba2dcf2d | |
| | 404d7b9978 | |
| | 6580a6bc01 | |
| | 3b15651bc6 | |
| | a55835f10c | |
| | b53b10ea61 | |
| | 7d5534d8e5 | |
| | 5ebb0c2e0b | |
| | a0a64c679f | |
| | 8e73678dae | |
| | c2862b24af | |
| | f9ec85f739 | |
@@ -1,6 +1,7 @@
 from app.assets.database.queries.asset import (
     asset_exists_by_hash,
     bulk_insert_assets,
+    create_stub_asset,
     get_asset_by_hash,
     get_existing_asset_ids,
     reassign_asset_references,
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
     UnenrichedReferenceRow,
     bulk_insert_references_ignore_conflicts,
     bulk_update_enrichment_level,
+    count_active_siblings,
     bulk_update_is_missing,
     bulk_update_needs_verify,
     convert_metadata_to_rows,
@@ -80,6 +82,8 @@ __all__ = [
     "bulk_insert_references_ignore_conflicts",
     "bulk_insert_tags_and_meta",
     "bulk_update_enrichment_level",
+    "count_active_siblings",
+    "create_stub_asset",
     "bulk_update_is_missing",
     "bulk_update_needs_verify",
     "convert_metadata_to_rows",

@@ -78,6 +78,18 @@ def upsert_asset(
     return asset, created, updated
 
 
+def create_stub_asset(
+    session: Session,
+    size_bytes: int,
+    mime_type: str | None = None,
+) -> Asset:
+    """Create a new asset with no hash (stub for later enrichment)."""
+    asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
+    session.add(asset)
+    session.flush()
+    return asset
+
+
 def bulk_insert_assets(
     session: Session,
     rows: list[dict],

@@ -114,6 +114,23 @@ def get_reference_by_file_path(
     )
 
 
+def count_active_siblings(
+    session: Session,
+    asset_id: str,
+    exclude_reference_id: str,
+) -> int:
+    """Count active (non-deleted) references to an asset, excluding one reference."""
+    return (
+        session.query(AssetReference)
+        .filter(
+            AssetReference.asset_id == asset_id,
+            AssetReference.id != exclude_reference_id,
+            AssetReference.deleted_at.is_(None),
+        )
+        .count()
+    )
+
+
 def reference_exists_for_asset_id(
     session: Session,
     asset_id: str,

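For reference, the new count_active_siblings boils down to SELECT COUNT(*) over non-deleted sibling rows. Below is a toy, runnable check of those filter semantics against an in-memory SQLite database; the AssetReference model here is a stand-in with only the three columns the filter touches, not the project's real model:

from datetime import datetime, timezone

from sqlalchemy import Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class AssetReference(Base):
    __tablename__ = "asset_reference"
    id = Column(String, primary_key=True)
    asset_id = Column(String)
    deleted_at = Column(DateTime, nullable=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        AssetReference(id="r1", asset_id="a1"),  # the ref being re-ingested
        AssetReference(id="r2", asset_id="a1"),  # an active sibling
        AssetReference(id="r3", asset_id="a1", deleted_at=datetime.now(timezone.utc)),  # soft-deleted
    ])
    session.flush()
    siblings = (
        session.query(AssetReference)
        .filter(
            AssetReference.asset_id == "a1",
            AssetReference.id != "r1",
            AssetReference.deleted_at.is_(None),
        )
        .count()
    )
    print(siblings)  # 1: only r2 counts; r1 is excluded, r3 is soft-deleted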
@@ -13,6 +13,7 @@ from app.assets.database.queries import (
     delete_references_by_ids,
     ensure_tags_exist,
     get_asset_by_hash,
+    get_reference_by_id,
     get_references_for_prefixes,
     get_unenriched_references,
     mark_references_missing_outside_prefixes,
@@ -338,6 +339,7 @@ def build_asset_specs(
                 "metadata": metadata,
                 "hash": asset_hash,
                 "mime_type": mime_type,
+                "job_id": None,
             }
         )
         tag_pool.update(tags)
@@ -426,6 +428,7 @@ def enrich_asset(
     except OSError:
         return new_level
 
+    initial_mtime_ns = get_mtime_ns(stat_p)
     rel_fname = compute_relative_filename(file_path)
     mime_type: str | None = None
     metadata = None
@@ -489,6 +492,18 @@ def enrich_asset(
         except Exception as e:
             logging.warning("Failed to hash %s: %s", file_path, e)
 
+    # Optimistic guard: if the reference's mtime_ns changed since we
+    # started (e.g. ingest_existing_file updated it), our results are
+    # stale — discard them to avoid overwriting fresh registration data.
+    ref = get_reference_by_id(session, reference_id)
+    if ref is None or ref.mtime_ns != initial_mtime_ns:
+        session.rollback()
+        logging.info(
+            "Ref %s mtime changed during enrichment, discarding stale result",
+            reference_id,
+        )
+        return ENRICHMENT_STUB
+
     if extract_metadata and metadata:
         system_metadata = metadata.to_user_metadata()
         set_reference_system_metadata(session, reference_id, system_metadata)

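The enrichment guard above is optimistic concurrency in miniature: record mtime_ns before the slow hashing work, re-read the reference afterwards, and throw the result away if the token moved. A self-contained sketch of the same pattern, with hypothetical names:

class Ref:
    def __init__(self, mtime_ns):
        self.mtime_ns = mtime_ns
        self.enriched = None

def enrich(ref, slow_hash):
    initial = ref.mtime_ns  # version token captured before the slow work
    result = slow_hash()    # expensive hashing, no locks held
    if ref.mtime_ns != initial:
        return False        # token moved: discard instead of overwriting fresh data
    ref.enriched = result
    return True

ref = Ref(mtime_ns=100)

def racy_hash():
    ref.mtime_ns = 200      # simulates ingest_existing_file re-registering the file mid-hash
    return "deadbeef"

assert enrich(ref, racy_hash) is False and ref.enriched is None
assert enrich(ref, lambda: "cafebabe") is True and ref.enriched == "cafebabe"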
@@ -77,7 +77,9 @@ class _AssetSeeder:
     """
 
     def __init__(self) -> None:
-        self._lock = threading.Lock()
+        # RLock is required because _run_scan() drains pending work while
+        # holding _lock and re-enters start() which also acquires _lock.
+        self._lock = threading.RLock()
         self._state = State.IDLE
         self._progress: Progress | None = None
         self._last_progress: Progress | None = None
@@ -92,6 +94,7 @@ class _AssetSeeder:
         self._prune_first: bool = False
         self._progress_callback: ProgressCallback | None = None
         self._disabled: bool = False
+        self._pending_enrich: dict | None = None
 
     def disable(self) -> None:
         """Disable the asset seeder, preventing any scans from starting."""
@@ -196,6 +199,42 @@ class _AssetSeeder:
             compute_hashes=compute_hashes,
         )
 
+    def enqueue_enrich(
+        self,
+        roots: tuple[RootType, ...] = ("models", "input", "output"),
+        compute_hashes: bool = False,
+    ) -> bool:
+        """Start an enrichment scan now, or queue it for after the current scan.
+
+        If the seeder is idle, starts immediately. Otherwise, the enrich
+        request is stored and will run automatically when the current scan
+        finishes.
+
+        Args:
+            roots: Tuple of root types to scan
+            compute_hashes: If True, compute blake3 hashes
+
+        Returns:
+            True if started immediately, False if queued for later
+        """
+        with self._lock:
+            if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
+                return True
+            if self._pending_enrich is not None:
+                existing_roots = set(self._pending_enrich["roots"])
+                existing_roots.update(roots)
+                self._pending_enrich["roots"] = tuple(existing_roots)
+                self._pending_enrich["compute_hashes"] = (
+                    self._pending_enrich["compute_hashes"] or compute_hashes
+                )
+            else:
+                self._pending_enrich = {
+                    "roots": roots,
+                    "compute_hashes": compute_hashes,
+                }
+            logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
+            return False
+
     def cancel(self) -> bool:
         """Request cancellation of the current scan.
 
@@ -381,9 +420,13 @@ class _AssetSeeder:
             return marked
         finally:
             with self._lock:
-                self._last_progress = self._progress
-                self._state = State.IDLE
-                self._progress = None
+                self._reset_to_idle()
+
+    def _reset_to_idle(self) -> None:
+        """Reset state to IDLE, preserving last progress. Caller must hold _lock."""
+        self._last_progress = self._progress
+        self._state = State.IDLE
+        self._progress = None
 
     def _is_cancelled(self) -> bool:
         """Check if cancellation has been requested."""
@@ -594,9 +637,18 @@ class _AssetSeeder:
                 },
             )
         with self._lock:
-            self._last_progress = self._progress
-            self._state = State.IDLE
-            self._progress = None
+            self._reset_to_idle()
+            pending = self._pending_enrich
+            if pending is not None:
+                self._pending_enrich = None
+                if not self.start_enrich(
+                    roots=pending["roots"],
+                    compute_hashes=pending["compute_hashes"],
+                ):
+                    logging.warning(
+                        "Pending enrich scan could not start (roots=%s)",
+                        pending["roots"],
+                    )
 
     def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
         """Run phase 1: fast scan to create stub records.

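enqueue_enrich implements a "run now, or coalesce into one pending request" queue, and the new comment explains why the plain Lock had to become an RLock: the scan worker finishes while holding _lock and re-enters start, which acquires _lock again. A reduced, runnable sketch of that shape (the Scanner class is hypothetical, not the seeder itself):

import threading

class Scanner:
    def __init__(self):
        self._lock = threading.RLock()
        self.busy = False
        self.pending = None  # at most one queued request, roots merged together

    def start(self, roots):
        with self._lock:
            if self.busy:
                return False
            self.busy = True
            print("scanning", roots)
            return True

    def enqueue(self, roots):
        with self._lock:
            if self.start(roots):      # re-entrant acquire: needs RLock
                return True
            merged = set(self.pending or ()) | set(roots)
            self.pending = tuple(sorted(merged))
            return False

    def finish(self):
        with self._lock:
            self.busy = False
            pending, self.pending = self.pending, None
            if pending is not None:
                self.start(pending)    # re-entrant acquire again, under the same lock

s = Scanner()
s.enqueue(("models",))   # starts immediately
s.enqueue(("input",))    # scanner busy: queued
s.enqueue(("output",))   # merged into the same pending request
s.finish()               # drains: scans ('input', 'output') in one pass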
@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
     DependencyMissingError,
     HashMismatchError,
     create_from_hash,
+    ingest_existing_file,
+    register_output_files,
     upload_from_temp_path,
 )
 from app.assets.database.queries import (
@@ -72,6 +74,8 @@ __all__ = [
     "delete_asset_reference",
     "get_asset_by_hash",
     "get_asset_detail",
+    "ingest_existing_file",
+    "register_output_files",
     "get_mtime_ns",
     "get_size_and_mtime_ns",
     "list_assets_page",

@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
     metadata: ExtractedMetadata | None
     hash: str | None
     mime_type: str | None
+    job_id: str | None
 
 
 class AssetRow(TypedDict):
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
     name: str
     preview_id: str | None
     user_metadata: dict[str, Any] | None
+    job_id: str | None
     created_at: datetime
     updated_at: datetime
     last_access_time: datetime
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
                 "name": spec["info_name"],
                 "preview_id": None,
                 "user_metadata": user_metadata,
+                "job_id": spec.get("job_id"),
                 "created_at": current_time,
                 "updated_at": current_time,
                 "last_access_time": current_time,

@@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
 import app.assets.services.hashing as hashing
 from app.assets.database.queries import (
     add_tags_to_reference,
+    count_active_siblings,
+    create_stub_asset,
     ensure_tags_exist,
     fetch_reference_and_asset,
     get_asset_by_hash,
+    get_reference_by_file_path,
@@ -23,7 +26,8 @@ from app.assets.database.queries import (
     upsert_reference,
     validate_tags_exist,
 )
-from app.assets.helpers import normalize_tags
+from app.assets.helpers import get_utc_now, normalize_tags
+from app.assets.services.bulk_ingest import batch_insert_seed_assets
 from app.assets.services.file_utils import get_size_and_mtime_ns
 from app.assets.services.path_utils import (
     compute_relative_filename,
@@ -130,6 +134,102 @@ def _ingest_file_from_path(
     )
 
 
+def register_output_files(
+    file_paths: Sequence[str],
+    user_metadata: UserMetadata = None,
+    job_id: str | None = None,
+) -> int:
+    """Register a batch of output file paths as assets.
+
+    Returns the number of files successfully registered.
+    """
+    registered = 0
+    for abs_path in file_paths:
+        if not os.path.isfile(abs_path):
+            continue
+        try:
+            if ingest_existing_file(
+                abs_path, user_metadata=user_metadata, job_id=job_id
+            ):
+                registered += 1
+        except Exception:
+            logging.exception("Failed to register output: %s", abs_path)
+    return registered
+
+
+def ingest_existing_file(
+    abs_path: str,
+    user_metadata: UserMetadata = None,
+    extra_tags: Sequence[str] = (),
+    owner_id: str = "",
+    job_id: str | None = None,
+) -> bool:
+    """Register an existing on-disk file as an asset stub.
+
+    If a reference already exists for this path, updates mtime_ns, job_id,
+    size_bytes, and resets enrichment so the enricher will re-hash it.
+
+    For brand-new paths, inserts a stub record (hash=NULL) for immediate
+    UX visibility.
+
+    Returns True if a row was inserted or updated, False otherwise.
+    """
+    locator = os.path.abspath(abs_path)
+    size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
+    mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
+    name, path_tags = get_name_and_tags_from_asset_path(abs_path)
+    tags = list(dict.fromkeys(path_tags + list(extra_tags)))
+
+    with create_session() as session:
+        existing_ref = get_reference_by_file_path(session, locator)
+        if existing_ref is not None:
+            now = get_utc_now()
+            existing_ref.mtime_ns = mtime_ns
+            existing_ref.job_id = job_id
+            existing_ref.is_missing = False
+            existing_ref.deleted_at = None
+            existing_ref.updated_at = now
+            existing_ref.enrichment_level = 0
+
+            asset = existing_ref.asset
+            if asset:
+                # If other refs share this asset, detach to a new stub
+                # instead of mutating the shared row.
+                siblings = count_active_siblings(session, asset.id, existing_ref.id)
+                if siblings > 0:
+                    new_asset = create_stub_asset(
+                        session,
+                        size_bytes=size_bytes,
+                        mime_type=mime_type or asset.mime_type,
+                    )
+                    existing_ref.asset_id = new_asset.id
+                else:
+                    asset.hash = None
+                    asset.size_bytes = size_bytes
+                    if mime_type:
+                        asset.mime_type = mime_type
+            session.commit()
+            return True
+
+        spec = {
+            "abs_path": abs_path,
+            "size_bytes": size_bytes,
+            "mtime_ns": mtime_ns,
+            "info_name": name,
+            "tags": tags,
+            "fname": os.path.basename(abs_path),
+            "metadata": None,
+            "hash": None,
+            "mime_type": mime_type,
+            "job_id": job_id,
+        }
+        if tags:
+            ensure_tags_exist(session, tags)
+        result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
+        session.commit()
+        return result.won_paths > 0
+
+
 def _register_existing_asset(
     asset_hash: str,
     name: str,

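Before ingest_existing_file touches the database it derives a mime type, a display name, and an ordered, de-duplicated tag list; the dict.fromkeys construction keeps first-seen order while dropping repeats. A self-contained sketch of that prep step (the path-to-tags helper below is a crude stand-in for get_name_and_tags_from_asset_path, which is not shown in this diff):

import mimetypes
import os

def name_and_path_tags(abs_path):
    # crude stand-in: file name as the display name, parent directories as tags
    parts = os.path.normpath(abs_path).split(os.sep)
    return parts[-1], [p for p in parts[1:-1] if p]

abs_path = "/output/renders/final/image_0001.png"  # hypothetical path
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
name, path_tags = name_and_path_tags(abs_path)
extra_tags = ["renders", "batch-7"]  # caller-supplied; may overlap path tags

# same construction as the diff: de-dup while keeping first-seen order
tags = list(dict.fromkeys(path_tags + list(extra_tags)))

print(mime_type)  # image/png
print(name)       # image_0001.png
print(tags)       # ['output', 'renders', 'final', 'batch-7'] (no duplicate 'renders')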
@@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
 
 def get_asset_category_and_relative_path(
     file_path: str,
-) -> tuple[Literal["input", "output", "models"], str]:
+) -> tuple[Literal["input", "output", "temp", "models"], str]:
     """Determine which root category a file path belongs to.
 
     Categories:
     - 'input': under folder_paths.get_input_directory()
     - 'output': under folder_paths.get_output_directory()
+    - 'temp': under folder_paths.get_temp_directory()
     - 'models': under any base path from get_comfy_models_folders()
 
     Returns:
@@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
     if _check_is_within(fp_abs, output_base):
         return "output", _compute_relative(fp_abs, output_base)
 
-    # 3) models (check deepest matching base to avoid ambiguity)
+    # 3) temp
+    temp_base = os.path.abspath(folder_paths.get_temp_directory())
+    if _check_is_within(fp_abs, temp_base):
+        return "temp", _compute_relative(fp_abs, temp_base)
+
+    # 4) models (check deepest matching base to avoid ambiguity)
     best: tuple[int, str, str] | None = None  # (base_len, bucket, rel_inside_bucket)
     for bucket, bases in get_comfy_models_folders():
         for b in bases:
@@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
         return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
 
     raise ValueError(
-        f"Path is not within input, output, or configured model bases: {file_path}"
+        f"Path is not within input, output, temp, or configured model bases: {file_path}"
     )

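The models branch keeps its "deepest matching base wins" rule so a path under a nested base (say, a loras directory configured inside a wider models directory) resolves to the more specific bucket. A standalone sketch of that rule, with hypothetical bases in place of get_comfy_models_folders():

import os

def deepest_match(fp_abs, buckets):
    best = None  # (base_len, bucket, relative_path_inside_bucket)
    for bucket, bases in buckets.items():
        for b in bases:
            base = os.path.abspath(b)
            try:
                rel = os.path.relpath(fp_abs, base)
            except ValueError:
                continue  # different drive on Windows
            if rel.startswith(".."):
                continue  # not inside this base
            if best is None or len(base) > best[0]:
                best = (len(base), bucket, rel)
    return best

buckets = {
    "checkpoints": ["/models"],
    "loras": ["/models/loras"],  # nested inside the wider base
}
print(deepest_match("/models/loras/style.safetensors", buckets))
# (13, 'loras', 'style.safetensors'): the deeper base wins, so the file
# lands in the 'loras' bucket rather than ambiguously under 'checkpoints'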
90  blueprints/.glsl/Color_Balance_15.frag  Normal file
@@ -0,0 +1,90 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
uniform float u_float5;
uniform float u_float6;
uniform float u_float7;
uniform float u_float8;
uniform bool u_bool0;

in vec2 v_texCoord;
out vec4 fragColor;

vec3 rgb2hsl(vec3 c) {
    float maxC = max(c.r, max(c.g, c.b));
    float minC = min(c.r, min(c.g, c.b));
    float l = (maxC + minC) * 0.5;
    if (maxC == minC) return vec3(0.0, 0.0, l);
    float d = maxC - minC;
    float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
    float h;
    if (maxC == c.r) {
        h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
    } else if (maxC == c.g) {
        h = (c.b - c.r) / d + 2.0;
    } else {
        h = (c.r - c.g) / d + 4.0;
    }
    h /= 6.0;
    return vec3(h, s, l);
}

float hue2rgb(float p, float q, float t) {
    if (t < 0.0) t += 1.0;
    if (t > 1.0) t -= 1.0;
    if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
    if (t < 1.0 / 2.0) return q;
    if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
    return p;
}

vec3 hsl2rgb(vec3 hsl) {
    float h = hsl.x, s = hsl.y, l = hsl.z;
    if (s == 0.0) return vec3(l);
    float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
    float p = 2.0 * l - q;
    return vec3(
        hue2rgb(p, q, h + 1.0 / 3.0),
        hue2rgb(p, q, h),
        hue2rgb(p, q, h - 1.0 / 3.0)
    );
}

void main() {
    vec4 tex = texture(u_image0, v_texCoord);
    vec3 color = tex.rgb;

    vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
    vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
    vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;

    float maxC = max(color.r, max(color.g, color.b));
    float minC = min(color.r, min(color.g, color.b));
    float lightness = (maxC + minC) * 0.5;

    // GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
    const float a = 0.25;
    const float b = 0.333;
    const float scale = 0.7;

    float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
    float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
               clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
    float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;

    color += sw * shadows + mw * midtones + hw * highlights;

    if (u_bool0) {
        vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
        hsl.z = lightness;
        color = hsl2rgb(hsl);
    }

    fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

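The three weight ramps are easy to sanity-check outside the shader. A NumPy transcription of the same clamp formulas, with the constants as in the shader:

import numpy as np

A, B, SCALE = 0.25, 0.333, 0.7

def weights(lightness):
    l = np.asarray(lightness, dtype=np.float32)
    sw = np.clip((l - B) / -A + 0.5, 0.0, 1.0) * SCALE
    mw = (np.clip((l - B) / A + 0.5, 0.0, 1.0)
          * np.clip((l + B - 1.0) / -A + 0.5, 0.0, 1.0) * SCALE)
    hw = np.clip((l + B - 1.0) / A + 0.5, 0.0, 1.0) * SCALE
    return sw, mw, hw

for l in (0.0, 0.25, 0.5, 0.75, 1.0):
    sw, mw, hw = weights(l)
    print(f"L={l:.2f}  shadows={sw:.3f}  midtones={mw:.3f}  highlights={hw:.3f}")
# At L=0 the shadow ramp dominates (0.7), at L=0.5 midtones peak,
# and at L=1 only the highlight ramp contributes.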
49  blueprints/.glsl/Color_Curves_8.frag  Normal file
@@ -0,0 +1,49 @@
#version 300 es
precision highp float;

uniform sampler2D u_image0;
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
uniform sampler2D u_curve1; // Red channel curve
uniform sampler2D u_curve2; // Green channel curve
uniform sampler2D u_curve3; // Blue channel curve

in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;

// GIMP-compatible curve lookup with manual linear interpolation.
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
//   index = value * (n_samples - 1)
//   f = fract(index)
//   result = (1-f) * samples[floor] + f * samples[ceil]
//
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
float applyCurve(sampler2D curve, float value) {
    value = clamp(value, 0.0, 1.0);

    float pos = value * 255.0;
    int lo = int(floor(pos));
    int hi = min(lo + 1, 255);
    float f = pos - float(lo);

    float a = texelFetch(curve, ivec2(lo, 0), 0).r;
    float b = texelFetch(curve, ivec2(hi, 0), 0).r;

    return a + f * (b - a);
}

void main() {
    vec4 color = texture(u_image0, v_texCoord);

    // GIMP order: per-channel curves first, then RGB master curve.
    // See gimp_curve_map_pixels() default case in gimpcurve-map.c:
    //   dest = colors_curve( channel_curve( src ) )
    float tmp_r = applyCurve(u_curve1, color.r);
    float tmp_g = applyCurve(u_curve2, color.g);
    float tmp_b = applyCurve(u_curve3, color.b);
    color.r = applyCurve(u_curve0, tmp_r);
    color.g = applyCurve(u_curve0, tmp_g);
    color.b = applyCurve(u_curve0, tmp_b);

    fragColor0 = vec4(color.rgb, color.a);
}

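applyCurve is a 256-entry LUT lookup with a manual lerp; texelFetch plays the role of plain array indexing. The same arithmetic in plain Python, handy for verifying a LUT on the CPU:

def apply_curve(samples, value):
    value = min(max(value, 0.0), 1.0)
    pos = value * (len(samples) - 1)   # index = value * (n_samples - 1)
    lo = int(pos)
    hi = min(lo + 1, len(samples) - 1)
    f = pos - lo                       # fractional part
    return (1.0 - f) * samples[lo] + f * samples[hi]

# identity LUT: apply_curve is a no-op up to float rounding
identity = [i / 255.0 for i in range(256)]
assert abs(apply_curve(identity, 0.4) - 0.4) < 1e-6

# a simple contrast S-curve LUT (smoothstep per entry), evaluated below midtone
s_curve = [3 * t * t - 2 * t * t * t for t in identity]
print(apply_curve(s_curve, 0.25))  # ~0.156: darkened below the midpoint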
1  blueprints/Color Balance.json  Normal file
File diff suppressed because one or more lines are too long

1  blueprints/Color Curves.json  Normal file
File diff suppressed because one or more lines are too long
@@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
 
 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
 
+CACHE_RAM_AUTO_GB = -1.0
+
 cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
-cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
+cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

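The new constant is a sentinel that defers the default to runtime. The help text pins the policy (25% of system RAM, clamped to 4-32 GB); here is a sketch of how that sentinel could be resolved, assuming psutil for the total-RAM query. The actual resolution site is not part of this hunk:

import psutil

CACHE_RAM_AUTO_GB = -1.0

def resolve_cache_ram_headroom(cache_ram_arg):
    if cache_ram_arg != CACHE_RAM_AUTO_GB:
        return cache_ram_arg  # user passed an explicit GB value
    total_gb = psutil.virtual_memory().total / (1024 ** 3)
    return min(max(total_gb * 0.25, 4.0), 32.0)  # 25% of RAM, clamped to [4, 32] GB

print(resolve_cache_ram_headroom(8.0))               # explicit: 8.0
print(resolve_cache_ram_headroom(CACHE_RAM_AUTO_GB)) # auto: clamped 25% of RAM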
725
comfy/ldm/rt_detr/rtdetr_v4.py
Normal file
725
comfy/ldm/rt_detr/rtdetr_v4.py
Normal file
@@ -0,0 +1,725 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
COCO_CLASSES = [
|
||||
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
|
||||
'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
|
||||
'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
|
||||
'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
|
||||
'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
|
||||
'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
|
||||
'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
|
||||
'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
|
||||
'remote','keyboard','cell phone','microwave','oven','toaster','sink',
|
||||
'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HGNetv2 backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvBNAct(nn.Module):
|
||||
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
|
||||
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
||||
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.act = nn.ReLU() if use_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
class LightConvBNAct(nn.Module):
|
||||
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
|
||||
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.conv1(x))
|
||||
|
||||
class _StemBlock(nn.Module):
|
||||
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
|
||||
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
|
||||
# padding is applied manually in forward (matching PaddlePaddle original)
|
||||
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stem1(x)
|
||||
x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a
|
||||
x2 = self.stem2a(x)
|
||||
x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b
|
||||
x2 = self.stem2b(x2)
|
||||
x1 = self.pool(x)
|
||||
return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
|
||||
|
||||
|
||||
class _HG_Block(nn.Module):
|
||||
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.residual = residual
|
||||
if light:
|
||||
self.layers = nn.ModuleList(
|
||||
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
|
||||
total = ic + layer_num * mc
|
||||
|
||||
self.aggregation = nn.Sequential(
|
||||
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
|
||||
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
outs = [x]
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outs.append(x)
|
||||
x = self.aggregation(torch.cat(outs, 1))
|
||||
return x + identity if self.residual else x
|
||||
|
||||
|
||||
class _HG_Stage(nn.Module):
|
||||
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
|
||||
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
if downsample:
|
||||
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
self.blocks = nn.Sequential(*[
|
||||
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
|
||||
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
|
||||
for i in range(num_blocks)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
return self.blocks(self.downsample(x))
|
||||
|
||||
|
||||
class HGNetv2(nn.Module):
|
||||
# B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
|
||||
_STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
|
||||
[128, 128, 512, 2, True, False, 3, 6],
|
||||
[512, 256, 1024, 5, True, True, 5, 6],
|
||||
[1024,512, 2048, 2, True, True, 5, 6]]
|
||||
|
||||
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
|
||||
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
|
||||
self.return_idx = list(return_idx)
|
||||
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.stem(x)
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
if i in self.return_idx:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConvNormLayer(nn.Module):
|
||||
"""Conv→act (expects pre-fused BN weights)."""
|
||||
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
p = (k - 1) // 2 if padding is None else padding
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
|
||||
class VGGBlock(nn.Module):
|
||||
"""Rep-VGG block (expects pre-fused weights)."""
|
||||
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
|
||||
class CSPLayer(nn.Module):
|
||||
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
h = int(oc * expansion)
|
||||
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
|
||||
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
|
||||
|
||||
|
||||
class RepNCSPELAN4(nn.Module):
|
||||
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
|
||||
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.c = c3 // 2
|
||||
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
|
||||
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
|
||||
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
y = list(self.cv1(x).split((self.c, self.c), 1))
|
||||
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
||||
return self.cv4(torch.cat(y, 1))
|
||||
|
||||
|
||||
class SCDown(nn.Module):
|
||||
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
|
||||
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
|
||||
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x):
|
||||
return self.cv2(self.cv1(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query, key, value, attn_mask=None):
|
||||
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
|
||||
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
|
||||
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
class _TransformerEncoderLayer(nn.Module):
|
||||
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
|
||||
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.activation = nn.GELU()
|
||||
|
||||
def forward(self, src, src_mask=None, pos_embed=None):
|
||||
q = k = src if pos_embed is None else src + pos_embed
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
|
||||
src = self.norm1(src + src2)
|
||||
src2 = self.linear2(self.activation(self.linear1(src)))
|
||||
return self.norm2(src + src2)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
|
||||
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, src, src_mask=None, pos_embed=None):
|
||||
for layer in self.layers:
|
||||
src = layer(src, src_mask=src_mask, pos_embed=pos_embed)
|
||||
return src
|
||||
|
||||
|
||||
class HybridEncoder(nn.Module):
|
||||
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
|
||||
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = list(in_channels)
|
||||
self.feat_strides = list(feat_strides)
|
||||
self.hidden_dim = hidden_dim
|
||||
self.use_encoder_idx = list(use_encoder_idx)
|
||||
self.pe_temperature = pe_temperature
|
||||
self.eval_spatial_size = eval_spatial_size
|
||||
self.out_channels = [hidden_dim] * len(in_channels)
|
||||
self.out_strides = list(feat_strides)
|
||||
|
||||
# channel projection (expects pre-fused weights)
|
||||
self.input_proj = nn.ModuleList([
|
||||
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
|
||||
for ch in in_channels
|
||||
])
|
||||
|
||||
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
|
||||
self.encoder = nn.ModuleList([
|
||||
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(use_encoder_idx))
|
||||
])
|
||||
|
||||
nb = round(3 * depth_mult)
|
||||
exp = expansion
|
||||
|
||||
# top-down FPN (dfine: lateral conv has no act)
|
||||
self.lateral_convs = nn.ModuleList(
|
||||
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
self.fpn_blocks = nn.ModuleList(
|
||||
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
|
||||
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
|
||||
self.downsample_convs = nn.ModuleList(
|
||||
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
self.pan_blocks = nn.ModuleList(
|
||||
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(len(in_channels) - 1)])
|
||||
|
||||
# cache positional embeddings for fixed spatial size
|
||||
if eval_spatial_size:
|
||||
for idx in self.use_encoder_idx:
|
||||
stride = self.feat_strides[idx]
|
||||
pe = self._build_pe(eval_spatial_size[1] // stride,
|
||||
eval_spatial_size[0] // stride,
|
||||
hidden_dim, pe_temperature)
|
||||
setattr(self, f'pos_embed{idx}', pe)
|
||||
|
||||
@staticmethod
|
||||
def _build_pe(w, h, dim=256, temp=10000.):
|
||||
assert dim % 4 == 0
|
||||
gw = torch.arange(w, dtype=torch.float32)
|
||||
gh = torch.arange(h, dtype=torch.float32)
|
||||
gw, gh = torch.meshgrid(gw, gh, indexing='ij')
|
||||
pdim = dim // 4
|
||||
omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
|
||||
ow = gw.flatten()[:, None] @ omega[None]
|
||||
oh = gh.flatten()[:, None] @ omega[None]
|
||||
return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]
|
||||
|
||||
def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
|
||||
|
||||
for i, enc_idx in enumerate(self.use_encoder_idx):
|
||||
h, w = proj[enc_idx].shape[2:]
|
||||
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
|
||||
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
|
||||
for layer in self.encoder[i].layers:
|
||||
src = layer(src, pos_embed=pe)
|
||||
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
|
||||
|
||||
n = len(self.in_channels)
|
||||
inner = [proj[-1]]
|
||||
for k in range(n - 1, 0, -1):
|
||||
j = n - 1 - k
|
||||
top = self.lateral_convs[j](inner[0])
|
||||
inner[0] = top
|
||||
up = F.interpolate(top, scale_factor=2., mode='nearest')
|
||||
inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))
|
||||
|
||||
outs = [inner[0]]
|
||||
for k in range(n - 1):
|
||||
outs.append(self.pan_blocks[k](
|
||||
torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
|
||||
return outs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder — DFINETransformer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
|
||||
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
|
||||
attention_weights : [bs, Lq, n_head, sum(pts)]
|
||||
"""
|
||||
_, c = value[0].shape[:2] # bs*n_head, c
|
||||
_, Lq, n_head, _, _ = sampling_locations.shape
|
||||
bs = sampling_locations.shape[0]
|
||||
n_h = n_head
|
||||
|
||||
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
|
||||
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
|
||||
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
|
||||
|
||||
sampled = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
val_l = value[lvl].reshape(bs * n_h, c, h, w)
|
||||
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
|
||||
|
||||
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
|
||||
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
|
||||
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
|
||||
out = out.reshape(bs, n_h * c, Lq)
|
||||
return out.permute(0, 2, 1) # [bs, Lq, hidden]
|
||||
|
||||
|
||||
class MSDeformableAttention(nn.Module):
|
||||
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.embed_dim, self.num_heads = embed_dim, num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
|
||||
self.num_points_list = pts
|
||||
self.offset_scale = offset_scale
|
||||
total = num_heads * sum(pts)
|
||||
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
|
||||
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
|
||||
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, query, ref_pts, value, spatial_shapes):
|
||||
bs, Lq = query.shape[:2]
|
||||
offsets = self.sampling_offsets(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list), 2)
|
||||
attn_w = F.softmax(
|
||||
self.attention_weights(query).reshape(
|
||||
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
|
||||
scale = self.num_points_scale.to(query).unsqueeze(-1)
|
||||
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
|
||||
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
|
||||
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
|
||||
|
||||
|
||||
class Gate(nn.Module):
|
||||
def __init__(self, d_model, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
|
||||
return self.norm(g1 * x1 + g2 * x2)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
|
||||
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
|
||||
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
|
||||
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
|
||||
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU()
|
||||
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
|
||||
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
|
||||
q = k = target if query_pos is None else target + query_pos
|
||||
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
|
||||
target = self.norm1(target + t2)
|
||||
t2 = self.cross_attn(
|
||||
target if query_pos is None else target + query_pos,
|
||||
ref_pts, value, spatial_shapes)
|
||||
target = self.gateway(target, t2)
|
||||
t2 = self.linear2(self.activation(self.linear1(target)))
|
||||
target = self.norm3((target + t2).clamp(-65504, 65504))
|
||||
return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FDR utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def weighting_function(reg_max, up, reg_scale):
|
||||
"""Non-uniform weighting function W(n) for FDR box regression."""
|
||||
ub1 = (abs(up[0]) * abs(reg_scale)).item()
|
||||
ub2 = ub1 * 2
|
||||
step = (ub1 + 1) ** (2 / (reg_max - 2))
|
||||
left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
|
||||
right = [ (step ** i) - 1 for i in range(1, reg_max // 2)]
|
||||
vals = [-ub2] + left + [0] + right + [ub2]
|
||||
return torch.tensor(vals, dtype=up.dtype, device=up.device)
|
||||
|
||||
|
||||
def distance2bbox(points, distance, reg_scale):
|
||||
"""Decode edge-distances → cxcywh boxes."""
|
||||
rs = abs(reg_scale).to(dtype=points.dtype)
|
||||
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
|
||||
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
|
||||
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
|
||||
y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs)
|
||||
x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
|
||||
return torch.stack([x0, y0, x1_, y1_], -1)
|
||||
|
||||
|
||||
class Integral(nn.Module):
|
||||
"""Sum Pr(n)·W(n) over the distribution bins."""
|
||||
def __init__(self, reg_max=32):
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
|
||||
def forward(self, x, project):
|
||||
shape = x.shape
|
||||
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
|
||||
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
|
||||
return x.reshape(list(shape[:-1]) + [-1])
|
||||
|
||||
|
||||
class LQE(nn.Module):
|
||||
"""Location Quality Estimator — refines class scores using corner distribution."""
|
||||
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.k, self.reg_max = k, reg_max
|
||||
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, scores, pred_corners):
|
||||
B, L, _ = pred_corners.shape
|
||||
prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1)
|
||||
topk, _ = prob.topk(self.k, -1)
|
||||
stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1)
|
||||
return scores + self.reg_conf(stat.reshape(B, L, -1))
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
self.nhead = nhead
|
||||
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.eval_idx + 1)
|
||||
])
|
||||
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
|
||||
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
|
||||
|
||||
def _value_op(self, memory, spatial_shapes):
|
||||
"""Reshape memory to per-level value tensors for deformable attention."""
|
||||
c = self.hidden_dim // self.nhead
|
||||
split = [h * w for h, w in spatial_shapes]
|
||||
val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim]
|
||||
# → [bs, n_head, c, sum_hw]
|
||||
val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw]
|
||||
return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l]
|
||||
|
||||
def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
|
||||
val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention
|
||||
|
||||
# reshape to [bs*n_head, c, h_l, w_l]
|
||||
value = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
v = val_split_flat[lvl] # [bs*n_head, c, h*w]
|
||||
value.append(v.reshape(v.shape[0], v.shape[1], h, w))
|
||||
|
||||
ref_pts = F.sigmoid(ref_pts_unact)
|
||||
output = target
|
||||
output_detach = pred_corners_undetach = 0
|
||||
|
||||
dec_bboxes, dec_logits = [], []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4]
|
||||
query_pos = query_pos_head(ref_pts).clamp(-10, 10)
|
||||
output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)
|
||||
|
||||
if i == 0:
|
||||
ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
|
||||
ref_unact = torch.log(ref_unact / (1 - ref_unact))
|
||||
pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
|
||||
ref_pts_initial = pre_bboxes.detach()
|
||||
|
||||
pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
|
||||
inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)
|
||||
|
||||
if i == self.eval_idx:
|
||||
scores = score_head[i](output)
|
||||
scores = self.lqe_layers[i](scores, pred_corners)
|
||||
dec_bboxes.append(inter_ref_bbox)
|
||||
dec_logits.append(scores)
|
||||
break
|
||||
|
||||
pred_corners_undetach = pred_corners
|
||||
ref_pts = inter_ref_bbox.detach()
|
||||
output_detach = output.detach()
|
||||
|
||||
return torch.stack(dec_bboxes), torch.stack(dec_logits)
|
||||
|
||||
|
||||
class DFINETransformer(nn.Module):
|
||||
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
|
||||
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
|
||||
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert len(feat_strides) == len(feat_channels)
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_queries = num_queries
|
||||
self.num_levels = num_levels
|
||||
self.eps = eps
|
||||
self.eval_spatial_size = eval_spatial_size
|
||||
|
||||
self.feat_strides = list(feat_strides)
|
||||
for i in range(num_levels - len(feat_strides)):
|
||||
self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))
|
||||
|
||||
# input projection (expects pre-fused weights)
|
||||
self.input_proj = nn.ModuleList()
|
||||
for ch in feat_channels:
|
||||
if ch == hidden_dim:
|
||||
self.input_proj.append(nn.Identity())
|
||||
else:
|
||||
self.input_proj.append(nn.Sequential(OrderedDict([
|
||||
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
|
||||
in_ch = feat_channels[-1]
|
||||
for i in range(num_levels - len(feat_channels)):
|
||||
self.input_proj.append(nn.Sequential(OrderedDict([
|
||||
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
|
||||
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
|
||||
in_ch = hidden_dim
|
||||
|
||||
# FDR parameters (non-trainable placeholders, set from config)
|
||||
self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
|
||||
self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
|
||||
|
||||
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
|
||||
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
|
||||
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
|
||||
self.enc_output = nn.Sequential(OrderedDict([
|
||||
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
|
||||
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
|
||||
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
|
||||
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
self.dec_score_head = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
|
||||
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
|
||||
self.dec_bbox_head = nn.ModuleList(
|
||||
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.eval_idx_ + 1)])
|
||||
self.integral = Integral(reg_max)
|
||||
|
||||
if eval_spatial_size:
|
||||
# Register as buffers so checkpoint values override the freshly-computed defaults
|
||||
anchors, valid_mask = self._gen_anchors()
|
||||
self.register_buffer('anchors', anchors)
|
||||
self.register_buffer('valid_mask', valid_mask)
|
||||
|
||||
def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
|
||||
if spatial_shapes is None:
|
||||
h0, w0 = self.eval_spatial_size
|
||||
spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
|
||||
anchors = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
|
||||
gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
|
||||
wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
|
||||
anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
|
||||
anchors = torch.cat(anchors, 1).to(device)
|
||||
valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
|
||||
return anchors, valid_mask
|
||||
|
||||
    def _encoder_input(self, feats: List[torch.Tensor]):
        proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
        for i in range(len(feats), self.num_levels):
            proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
        flat, shapes = [], []
        for f in proj:
            _, _, h, w = f.shape
            flat.append(f.flatten(2).permute(0, 2, 1))
            shapes.append([h, w])
        return torch.cat(flat, 1), shapes

    def _decoder_input(self, memory: torch.Tensor):
        anchors, valid_mask = self.anchors.to(memory), self.valid_mask
        if memory.shape[0] > 1:
            anchors = anchors.repeat(memory.shape[0], 1, 1)

        mem = valid_mask.to(memory) * memory
        out_mem = self.enc_output(mem)
        logits = self.enc_score_head(out_mem)
        _, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
        idx_e = idx.unsqueeze(-1)
        topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
        topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
        topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
        return topk_mem.detach(), topk_ref.detach()

    def forward(self, feats: List[torch.Tensor]):
        memory, shapes = self._encoder_input(feats)
        content, ref = self._decoder_input(memory)
        out_bboxes, out_logits = self.decoder(
            content, ref, memory, shapes,
            self.dec_bbox_head, self.dec_score_head,
            self.query_pos_head, self.pre_bbox_head, self.integral)
        return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}


# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------

class RTv4(nn.Module):
    def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.operations = operations

        self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
        self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
        self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
                                        feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.load_device = comfy.model_management.get_torch_device()

    def _forward(self, x: torch.Tensor):
        return self.decoder(self.encoder(self.backbone(x)))

    def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
        logits = outputs['pred_logits']
        boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
        boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
        scores = F.sigmoid(logits)
        scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
        labels = idx % self.num_classes
        boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
        return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]

    def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
        outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
        return self.postprocess(outputs, orig_size)

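A minimal sketch (standalone, made-up sizes, not part of the diff) of the top-k decode
used by postprocess: idx % num_classes gives the class label and idx // num_classes
picks the query whose box is gathered.

    import torch
    import torchvision
    num_queries, num_classes = 4, 3
    logits = torch.randn(1, num_queries, num_classes)
    boxes = torchvision.ops.box_convert(torch.rand(1, num_queries, 4), 'cxcywh', 'xyxy')
    scores, idx = torch.topk(logits.sigmoid().flatten(1), num_queries, dim=-1)
    labels = idx % num_classes
    picked = boxes.gather(1, (idx // num_classes).unsqueeze(-1).expand(-1, -1, 4))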
@@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
     return dest_views

 aimdo_enabled = False
+
+extra_ram_release_callback = None
+RAM_CACHE_HEADROOM = 0
+
+def set_ram_cache_release_state(callback, headroom):
+    global extra_ram_release_callback
+    global RAM_CACHE_HEADROOM
+    extra_ram_release_callback = callback
+    RAM_CACHE_HEADROOM = max(0, int(headroom))
+
+def extra_ram_release(target):
+    if extra_ram_release_callback is None:
+        return 0
+    return extra_ram_release_callback(target)

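A minimal sketch (toy cache, made-up sizes) of wiring a byte-counting cache into the
release hook; only set_ram_cache_release_state and extra_ram_release come from the
code above.

    toy_cache = {"a": 512, "b": 1024, "c": 2048}

    def release(target):
        freed = 0
        while toy_cache and freed < target:
            _, size = toy_cache.popitem()
            freed += size
        return freed

    set_ram_cache_release_state(release, headroom=4096)
    extra_ram_release(1500)  # evicts entries until at least 1500 bytes are freed (or the cache empties)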
@@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model
 import comfy.ldm.kandinsky5.model
 import comfy.ldm.anima.model
 import comfy.ldm.ace.ace_step15
+import comfy.ldm.rt_detr.rtdetr_v4

 import comfy.model_management
 import comfy.patcher_extension
@@ -890,7 +891,7 @@ class Flux(BaseModel):
         return torch.cat((image, mask), dim=1)

     def encode_adm(self, **kwargs):
-        return kwargs["pooled_output"]
+        return kwargs.get("pooled_output", None)

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
@@ -1957,3 +1958,7 @@ class Kandinsky5Image(Kandinsky5):

     def concat_cond(self, **kwargs):
         return None
+
+class RT_DETR_v4(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
@@ -698,6 +698,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["audio_model"] = "ace1.5"
         return dit_config

+    if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
+        dit_config = {}
+        dit_config["image_model"] = "RT_DETR_v4"
+        dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None

@@ -55,6 +55,7 @@ total_vram = 0

 # Training Related State
 in_training = False
+training_fp8_bwd = False


 def get_supported_float8_types():
@@ -668,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
-        if shift_model.device == device:
+        if device is None or shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
                 can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                 shift_model.currently_used = False
@@ -678,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
         i = x[-1]
         memory_to_free = 1e32
         pins_to_free = 1e32
-        if not DISABLE_SMART_MEMORY:
-            memory_to_free = memory_required - get_free_memory(device)
+        if not DISABLE_SMART_MEMORY or device is None:
+            memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
             pins_to_free = pins_required - get_free_ram()
         if current_loaded_models[i].model.is_dynamic() and for_dynamic:
             #don't actually unload dynamic models for the sake of other dynamic models
@@ -707,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins

     if len(unloaded_model) > 0:
         soft_empty_cache()
-    else:
+    elif device is not None:
         if vram_state != VRAMState.HIGH_VRAM:
             mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
             if mem_free_torch > mem_free_total * 0.25:
@@ -300,9 +300,6 @@ class ModelPatcher:
     def model_mmap_residency(self, free=False):
         return comfy.model_management.module_mmap_residency(self.model, free=free)

-    def get_ram_usage(self):
-        return self.model_size()
-
     def loaded_size(self):
         return self.model.model_loaded_weight_memory

comfy/ops.py
@@ -777,8 +777,16 @@ from .quant_ops import (


 class QuantLinearFunc(torch.autograd.Function):
-    """Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
-    Handles any input rank by flattening to 2D for matmul and restoring shape after.
+    """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
+
+    When training_fp8_bwd is enabled:
+      - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
+      - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
+      - Cached input is FP8 (half the memory of bf16)
+
+    When training_fp8_bwd is disabled:
+      - Forward: quantize input per layout, use quantized matmul
+      - Backward: dequantize weight to compute_dtype, use standard matmul
     """

     @staticmethod
@@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
         input_shape = input_float.shape
         inp = input_float.detach().flatten(0, -2)  # zero-cost view to 2D

-        # Quantize input (same as inference path)
+        # Quantize input for forward (same layout as weight)
         if layout_type is not None:
             q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
         else:
@@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):

         output = torch.nn.functional.linear(q_input, w, b)

-        # Restore original input shape
+        # Unflatten output to match original input shape
         if len(input_shape) > 2:
             output = output.unflatten(0, input_shape[:-1])

-        ctx.save_for_backward(input_float, weight)
+        # Save for backward
         ctx.input_shape = input_shape
         ctx.has_bias = bias is not None
         ctx.compute_dtype = compute_dtype
         ctx.weight_requires_grad = weight.requires_grad
+        ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
+
+        if ctx.fp8_bwd:
+            # Cache FP8 quantized input — half the memory of bf16
+            if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
+                ctx.q_input = q_input  # already FP8, reuse
+            else:
+                # NVFP4 or other layout — quantize input to FP8 for backward
+                ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
+            ctx.save_for_backward(weight)
+        else:
+            ctx.q_input = None
+            ctx.save_for_backward(input_float, weight)

         return output

     @staticmethod
     @torch.autograd.function.once_differentiable
     def backward(ctx, grad_output):
-        input_float, weight = ctx.saved_tensors
         compute_dtype = ctx.compute_dtype
         grad_2d = grad_output.flatten(0, -2).to(compute_dtype)

-        # Dequantize weight to compute dtype for backward matmul
-        if isinstance(weight, QuantizedTensor):
-            weight_f = weight.dequantize().to(compute_dtype)
+        # Value casting — only difference between fp8 and non-fp8 paths
+        if ctx.fp8_bwd:
+            weight, = ctx.saved_tensors
+            # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
+            grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
+            if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
+                weight_mm = weight
+            elif isinstance(weight, QuantizedTensor):
+                weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
+            else:
+                weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
+            input_mm = ctx.q_input
         else:
-            weight_f = weight.to(compute_dtype)
+            input_float, weight = ctx.saved_tensors
+            # Standard tensors → torch.mm does regular matmul
+            grad_mm = grad_2d
+            if isinstance(weight, QuantizedTensor):
+                weight_mm = weight.dequantize().to(compute_dtype)
+            else:
+                weight_mm = weight.to(compute_dtype)
+            input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None

-        # grad_input = grad_output @ weight
-        grad_input = torch.mm(grad_2d, weight_f)
+        # Computation — same for both paths, dispatch handles the rest
+        grad_input = torch.mm(grad_mm, weight_mm)
         if len(ctx.input_shape) > 2:
             grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])

         # grad_weight (only if weight requires grad, typically frozen for quantized training)
         grad_weight = None
         if ctx.weight_requires_grad:
-            input_f = input_float.flatten(0, -2).to(compute_dtype)
-            grad_weight = torch.mm(grad_2d.t(), input_f)
+            grad_weight = torch.mm(grad_mm.t(), input_mm)

         # grad_bias
         grad_bias = None
         if ctx.has_bias:
             grad_bias = grad_2d.sum(dim=0)
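A side note on the two FP8 layouts chosen above: gradients are cast to E5M2 for its
wider exponent range, while weights and activations use E4M3 for its extra mantissa
bit. The formats' dynamic ranges (standard torch dtypes):

    import torch
    torch.finfo(torch.float8_e4m3fn).max  # 448.0
    torch.finfo(torch.float8_e5m2).max    # 57344.0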
@@ -895,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             weight = state_dict.pop(weight_key, None)
             if weight is None:
                 logging.warning(f"Missing weight for layer {layer_name}")
+                self.weight = None
                 return

             manually_loaded_keys = [weight_key]
@@ -1001,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             if self.bias is not None:
                 sd["{}bias".format(prefix)] = self.bias

+            if self.weight is None:
+                return sd
+
             if isinstance(self.weight, QuantizedTensor):
                 sd_out = self.weight.state_dict("{}weight".format(prefix))
                 for k in sd_out:
@@ -2,6 +2,7 @@ import comfy.model_management
 import comfy.memory_management
 import comfy_aimdo.host_buffer
 import comfy_aimdo.torch
+import psutil

 from comfy.cli_args import args

@@ -12,6 +13,11 @@ def pin_memory(module):
     if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
         return
+    #FIXME: This is a RAM cache trigger event
+    ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
+    #we split the difference and assume half the RAM cache headroom is for us
+    if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
+        comfy.memory_management.extra_ram_release(ram_headroom)

     size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])

     if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:

comfy/sd.py
@@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
 import comfy.text_encoders.anima
 import comfy.text_encoders.ace15
 import comfy.text_encoders.longcat_image
+import comfy.text_encoders.qwen35

 import comfy.model_patcher
 import comfy.lora
@@ -279,9 +280,6 @@ class CLIP:
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n

-    def get_ram_usage(self):
-        return self.patcher.get_ram_usage()
-
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -425,13 +423,13 @@ class CLIP:
     def get_key_patches(self):
         return self.patcher.get_key_patches()

-    def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
+    def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
         self.cond_stage_model.reset_clip_options()

         self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"layer": None})
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
+        return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)

     def decode(self, token_ids, skip_special_tokens=True):
         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -839,9 +837,6 @@ class VAE:
         self.size = comfy.model_management.module_size(self.first_stage_model)
         return self.size

-    def get_ram_usage(self):
-        return self.model_size()
-
     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
             raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@@ -1228,6 +1223,11 @@ class TEModel(Enum):
     QWEN3_8B = 20
     QWEN3_06B = 21
     GEMMA_3_4B_VISION = 22
+    QWEN35_08B = 23
+    QWEN35_2B = 24
+    QWEN35_4B = 25
+    QWEN35_9B = 26
+    QWEN35_27B = 27


 def detect_te_model(sd):
@@ -1267,6 +1267,17 @@ def detect_te_model(sd):
             return TEModel.QWEN25_3B
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
+    if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
+        weight = sd['model.language_model.layers.0.input_layernorm.weight']
+        if weight.shape[0] == 1024:
+            return TEModel.QWEN35_08B
+        if weight.shape[0] == 2560:
+            return TEModel.QWEN35_4B
+        if weight.shape[0] == 4096:
+            return TEModel.QWEN35_9B
+        if weight.shape[0] == 5120:
+            return TEModel.QWEN35_27B
+        return TEModel.QWEN35_2B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         weight = sd['model.layers.0.post_attention_layernorm.weight']
         if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1299,11 +1310,12 @@ def t5xxl_detect(clip_data):
     return {}

 def llama_detect(clip_data):
-    weight_name = "model.layers.0.self_attn.k_proj.weight"
+    weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]

     for sd in clip_data:
-        if weight_name in sd:
-            return comfy.text_encoders.hunyuan_video.llama_detect(sd)
+        for weight_name in weight_names:
+            if weight_name in sd:
+                return comfy.text_encoders.hunyuan_video.llama_detect(sd)

     return {}

@@ -1431,6 +1443,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         elif te_model == TEModel.JINA_CLIP_2:
             clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
             clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
+        elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
+            clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
+            qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
+            clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
+            clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
         elif te_model == TEModel.QWEN3_06B:
             clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
@@ -1719,15 +1736,16 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
     """
     dtype = model_options.get("dtype", None)

-    custom_operations = model_options.get("custom_operations", None)
-    if custom_operations is None:
-        sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
-
     #Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
     if len(temp_sd) > 0:
         sd = temp_sd

+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)

@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))

-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
         if isinstance(tokens, dict):
             tokens_only = next(iter(tokens.values()))  # todo: get this better?
         else:
             tokens_only = tokens
         tokens_only = [[t[0] for t in b] for b in tokens_only]
         embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
-        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
+        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)

 def parse_parentheses(string):
     result = []
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)

-    def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
-        return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
+    def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
+        return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)

@@ -1734,6 +1734,21 @@ class LongCatImage(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
+class RT_DETR_v4(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "RT_DETR_v4",
+    }
+
+    supported_inference_dtypes = [torch.float16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.RT_DETR_v4(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return None
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]

 models += [SVD_img2vid]

@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
-    lm_head: bool = False
+    lm_head: bool = True
    stop_tokens = [151643, 151645]

 @dataclass
@@ -655,6 +655,17 @@ class Llama2_(nn.Module):
         if config.lm_head:
             self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

+    def get_past_len(self, past_key_values):
+        return past_key_values[0][2]
+
+    def compute_freqs_cis(self, position_ids, device):
+        return precompute_freqs_cis(self.config.head_dim,
+                                    position_ids,
+                                    self.config.rope_theta,
+                                    self.config.rope_scale,
+                                    self.config.rope_dims,
+                                    device=device)
+
     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
         if embeds is not None:
             x = embeds

@@ -667,17 +678,12 @@ class Llama2_(nn.Module):
         seq_len = x.shape[1]
         past_len = 0
         if past_key_values is not None and len(past_key_values) > 0:
-            past_len = past_key_values[0][2]
+            past_len = self.get_past_len(past_key_values)

         if position_ids is None:
             position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)

-        freqs_cis = precompute_freqs_cis(self.config.head_dim,
-                                         position_ids,
-                                         self.config.rope_theta,
-                                         self.config.rope_scale,
-                                         self.config.rope_dims,
-                                         device=x.device)
+        freqs_cis = self.compute_freqs_cis(position_ids, x.device)

         mask = None
         if attention_mask is not None:
@@ -812,9 +818,16 @@ class BaseGenerate:
             comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
         return x

-    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
-        device = embeds.device
+    def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
+        model_config = self.model.config
+        past_key_values = []
+        for x in range(model_config.num_hidden_layers):
+            past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
+                                    torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+        return past_key_values
+
+    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
+        device = embeds.device

         if stop_tokens is None:
             stop_tokens = self.model.config.stop_tokens
@@ -829,11 +842,8 @@ class BaseGenerate:
         if embeds.ndim == 2:
             embeds = embeds.unsqueeze(0)

-        past_key_values = [] #kv_cache init
         max_cache_len = embeds.shape[1] + max_length
-        for x in range(model_config.num_hidden_layers):
-            past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
-                                    torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+        past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)

         generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None

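A note on the refactor above: init_kv_cache preallocates each layer's key/value
buffers to prompt length plus max_length, so decode steps write in place instead of
reallocating. Schematically (names from the code above):

    # inside generate(), after embeds has been batched:
    max_cache_len = embeds.shape[1] + max_length
    past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)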
@@ -844,7 +854,7 @@ class BaseGenerate:
         for step in tqdm(range(max_length), desc="Generating tokens"):
             x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
             logits = self.logits(x)[:, -1]
-            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
+            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
             token_id = next_token[0].item()
             generated_token_ids.append(token_id)
@@ -856,7 +866,7 @@ class BaseGenerate:

         return generated_token_ids

-    def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
+    def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):

         if not do_sample or temperature == 0.0:
             return torch.argmax(logits, dim=-1, keepdim=True)

@@ -867,6 +877,11 @@ class BaseGenerate:
             for token_id in set(token_history):
                 logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty

+        if presence_penalty is not None and presence_penalty != 0.0:
+            for i in range(logits.shape[0]):
+                for token_id in set(token_history):
+                    logits[i, token_id] -= presence_penalty
+
         if temperature != 1.0:
             logits = logits / temperature

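A standalone sketch of how the two penalties above differ: repetition penalty rescales
a seen token's logit multiplicatively (so its effect weakens near zero), while
presence penalty subtracts a flat amount once per seen token.

    import torch
    logits = torch.tensor([[2.0, -1.0, 0.05]])
    seen = {0, 2}
    for t in seen:
        logits[0, t] *= 1.2 if logits[0, t] < 0 else 1 / 1.2  # repetition penalty
    for t in seen:
        logits[0, t] -= 0.4                                   # presence penalty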
@@ -897,6 +912,9 @@ class BaseGenerate:
 class BaseQwen3:
     def logits(self, x):
         input = x[:, -1:]
+        if self.model.config.lm_head:
+            return self.model.lm_head(input)
+
         module = self.model.embed_tokens

         offload_stream = None

@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
         self.dtypes.add(dtype)
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
         tokens_only = [[t[0] for t in b] for b in tokens]
         embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
         comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
-        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
+        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>

 class DualLinearProjection(torch.nn.Module):
     def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):

         return out.to(device=out_device, dtype=torch.float), pooled, extra

-    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
-        return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
+    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
+        return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)

     def load_sd(self, sd):
         if "model.layers.47.self_attn.q_norm.weight" in sd:

comfy/text_encoders/qwen35.py
@@ -0,0 +1,833 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
import os
import math

import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy import sd1_clip
import comfy.text_encoders.qwen_vl

from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope


def _qwen35_layer_types(n):
    return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]

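# Illustrative (not in the original file): every 4th layer is full attention, the rest
# are DeltaNet linear attention, e.g.
#   _qwen35_layer_types(8) == ['linear_attention'] * 3 + ['full_attention']
#                           + ['linear_attention'] * 3 + ['full_attention']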
@dataclass
class Qwen35Config:
    vocab_size: int = 248320
    hidden_size: int = 2048
    intermediate_size: int = 6144
    num_hidden_layers: int = 24
    # Full attention params
    num_attention_heads: int = 8
    num_key_value_heads: int = 2
    head_dim: int = 256
    partial_rotary_factor: float = 0.25
    # Linear attention (DeltaNet) params
    linear_num_key_heads: int = 16
    linear_num_value_heads: int = 16
    linear_key_head_dim: int = 128
    linear_value_head_dim: int = 128
    conv_kernel_size: int = 4
    # Shared params
    max_position_embeddings: int = 32768
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000000.0
    mrope_section: list = field(default_factory=lambda: [11, 11, 10])
    layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
    rms_norm_add: bool = True
    mlp_activation: str = "silu"
    qkv_bias: bool = False
    final_norm: bool = True
    lm_head: bool = False
    stop_tokens: list = field(default_factory=lambda: [248044, 248046])
    # These are needed for BaseLlama/BaseGenerate compatibility but unused directly
    transformer_type: str = "qwen35_2b"
    rope_dims: list = None
    rope_scale: float = None

QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)

QWEN35_MODELS = {
    "qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
    "qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
    "qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
    "qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
    "qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
}


def _make_config(model_type, config_dict={}):
    overrides = QWEN35_MODELS.get(model_type, {}).copy()
    overrides.pop("vision", None)
    if "num_hidden_layers" in overrides:
        overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
    overrides.update(config_dict)
    return Qwen35Config(**overrides)


class RMSNormGated(RMSNorm):
    def forward(self, x, gate):
        return super().forward(x) * F.silu(gate.to(x.dtype))

def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
    initial_dtype = query.dtype
    query = F.normalize(query, dim=-1)
    key = F.normalize(key, dim=-1)
    query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    for i in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state


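# Illustrative shape check (not in the original file): inputs are [B, seq, heads, dim]
# with beta/g as [B, seq, heads]; the output keeps the input layout.
#   q = torch.randn(1, 8, 4, 16); k = torch.randn(1, 8, 4, 16); v = torch.randn(1, 8, 4, 16)
#   g = torch.zeros(1, 8, 4); beta = torch.full((1, 8, 4), 0.5)
#   out, _ = torch_chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=4)
#   assert out.shape == (1, 8, 4, 16)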
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
    # conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
    # weight: [channels, kernel_size]
    state_len = conv_state.shape[-1]
    combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # [B, channels, kernel_size]
    conv_state.copy_(combined[:, :, -state_len:])
    out = (combined * weight).sum(dim=-1, keepdim=True)  # [B, channels, 1]
    if bias is not None:
        out = out + bias.unsqueeze(0).unsqueeze(-1)
    return F.silu(out).to(x.dtype)


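# Illustrative single-step use (not in the original file): the state holds the last
# kernel_size-1 samples per channel; each call shifts in one new sample and returns
# one causal depthwise-conv output.
#   B, C, K = 1, 3, 4
#   weight = torch.randn(C, K); state = torch.zeros(B, C, K - 1)
#   y = torch_causal_conv1d_update(torch.randn(B, C, 1), state, weight)  # [B, C, 1]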
# GatedDeltaNet - Linear Attention Layer

class GatedDeltaNet(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()

        hidden = config.hidden_size
        self.num_key_heads = config.linear_num_key_heads
        self.num_value_heads = config.linear_num_value_heads
        self.key_head_dim = config.linear_key_head_dim
        self.value_head_dim = config.linear_value_head_dim
        self.conv_kernel_size = config.conv_kernel_size

        key_dim = self.num_key_heads * self.key_head_dim
        value_dim = self.num_value_heads * self.value_head_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        conv_dim = key_dim * 2 + value_dim

        self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
        self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
        self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
        self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
        self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)

        self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
        self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))

        self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
                                 groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)

        self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(self, x, past_key_value=None, **kwargs):
        batch_size, seq_len, _ = x.shape

        use_recurrent = (
            past_key_value is not None
            and past_key_value[2] > 0
            and seq_len == 1
        )

        # Projections (shared)
        mixed_qkv = self.in_proj_qkv(x).transpose(1, 2)  # [B, conv_dim, seq_len]
        z = self.in_proj_z(x)
        b = self.in_proj_b(x)
        a = self.in_proj_a(x)

        # Conv1d
        if use_recurrent:
            recurrent_state, conv_state, step_index = past_key_value
            conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
            conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
            mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
        else:
            if past_key_value is not None:
                recurrent_state, conv_state, step_index = past_key_value
                conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        # Split QKV and compute beta/g
        mixed_qkv = mixed_qkv.transpose(1, 2)  # [B, seq_len, conv_dim]
        query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
        beta = b.sigmoid()
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())

        # Delta rule
        if use_recurrent:
            # single-token path: work in [B, heads, dim] without seq dim
            query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
            key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
            value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)

            if self.num_value_heads != self.num_key_heads:
                rep = self.num_value_heads // self.num_key_heads
                query = query.repeat_interleave(rep, dim=1)
                key = key.repeat_interleave(rep, dim=1)

            scale = self.key_head_dim ** -0.5
            q = F.normalize(query.float(), dim=-1) * scale
            k = F.normalize(key.float(), dim=-1)
            v = value.float()
            beta_t = beta.reshape(batch_size, -1)
            g_t = g.reshape(batch_size, -1).exp()

            # In-place state update: [B, heads, k_dim, v_dim]
            recurrent_state.mul_(g_t[:, :, None, None])
            kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
            delta = (v - kv_mem) * beta_t[:, :, None]
            recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
            core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)

            core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
            present_key_value = (recurrent_state, conv_state, step_index + 1)
        else:
            query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
            key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
            value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)

            if self.num_value_heads != self.num_key_heads:
                rep = self.num_value_heads // self.num_key_heads
                query = query.repeat_interleave(rep, dim=2)
                key = key.repeat_interleave(rep, dim=2)

            core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
                query, key, value, g=g, beta=beta,
                initial_state=None,
                output_final_state=past_key_value is not None,
            )

            present_key_value = None
            if past_key_value is not None:
                if last_recurrent_state is not None:
                    recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
                present_key_value = (recurrent_state, conv_state, step_index + seq_len)

        # Gated norm + output projection (shared)
        core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
        output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
        return output, present_key_value


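# Note on the recurrent path above (restating the in-place update in equations,
# assuming a per-head state S of shape [k_dim, v_dim]):
#   S   <- exp(g) * S                 (decay)
#   err <- v - k^T S                  (what the memory currently gets wrong for this key)
#   S   <- S + beta * k err^T         (delta-rule write)
#   out <- q^T S                      (read)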
# GatedAttention - Full Attention with output gating
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
    """Compute RoPE frequencies for partial rotary embeddings."""
    theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))

    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()

    if mrope_section is not None and position_ids.shape[0] == 3:
        mrope_section_2 = [s * 2 for s in mrope_section]
        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)

    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    sin_split = sin.shape[-1] // 2
    return (cos, sin[..., :sin_split], -sin[..., sin_split:])


def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
    """Apply RoPE to only the first rotary_dim dimensions."""
    xq_rot = xq[..., :rotary_dim]
    xq_pass = xq[..., rotary_dim:]
    xk_rot = xk[..., :rotary_dim]
    xk_pass = xk[..., rotary_dim:]

    xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)

    xq = torch.cat([xq_rot, xq_pass], dim=-1)
    xk = torch.cat([xk_rot, xk_pass], dim=-1)
    return xq, xk


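# Illustrative (not in the original file): with head_dim=256 and
# partial_rotary_factor=0.25, only the first 64 dimensions of each head are rotated;
# the remaining 192 pass through position-free.
#   rotary_dim = int(256 * 0.25)  # 64, as computed in GatedAttention.__init__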
class GatedAttention(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.inner_size = self.num_heads * self.head_dim
        self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)

        # q_proj outputs 2x: query + gate
        self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

        # QK norms with (1+weight) scaling
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
        batch_size, seq_length, _ = x.shape

        # Project Q (with gate), K, V
        qg = self.q_proj(x)
        # Split into query and gate: each is [B, seq, inner_size]
        qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
        xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
        gate = gate.reshape(batch_size, seq_length, -1)  # [B, seq, inner_size]

        xk = self.k_proj(x)
        xv = self.v_proj(x)

        xq = self.q_norm(xq).transpose(1, 2)  # [B, heads, seq, head_dim]
        xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
        xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply partial RoPE
        xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)

        # KV cache
        present_key_value = None
        if past_key_value is not None:
            past_key, past_value, index = past_key_value
            num_tokens = xk.shape[2]
            if past_key.shape[2] >= (index + num_tokens):
                past_key[:, :, index:index + num_tokens] = xk
                past_value[:, :, index:index + num_tokens] = xv
                xk = past_key[:, :, :index + num_tokens]
                xv = past_value[:, :, :index + num_tokens]
                present_key_value = (past_key, past_value, index + num_tokens)
            else:
                if index > 0:
                    xk = torch.cat((past_key[:, :, :index], xk), dim=2)
                    xv = torch.cat((past_value[:, :, :index], xv), dim=2)
                present_key_value = (xk, xv, index + num_tokens)

        # Expand KV heads for GQA
        if self.num_heads != self.num_kv_heads:
            xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        output = output * gate.sigmoid()

        return self.o_proj(output), present_key_value


# Hybrid Transformer Block
class Qwen35TransformerBlock(nn.Module):
    def __init__(self, config, index, device=None, dtype=None, ops=None):
        super().__init__()
        self.layer_type = config.layer_types[index]
        if self.layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
        else:
            self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
        if self.layer_type == "linear_attention":
            h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
        else:
            h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)

        x = x + h
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x, present_key_value


# Qwen35 Transformer Backbone
class Qwen35Transformer(Llama2_):
    def __init__(self, config, device=None, dtype=None, ops=None):
        nn.Module.__init__(self)
        self.config = config
        self.vocab_size = config.vocab_size
        self.normalize_in = False

        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
            for i in range(config.num_hidden_layers)
        ])

        if config.final_norm:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        else:
            self.norm = None

        if config.lm_head:
            self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

    def get_past_len(self, past_key_values):
        for i, layer in enumerate(self.layers):
            if layer.layer_type == "full_attention":
                if len(past_key_values) > i:
                    return past_key_values[i][2]
                break
        return 0

    def compute_freqs_cis(self, position_ids, device):
        rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
        return precompute_partial_rope(
            self.config.head_dim, rotary_dim, position_ids,
            self.config.rope_theta, device=device,
            mrope_section=self.config.mrope_section,
        )


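# Note (illustrative): cache entries are heterogeneous in this hybrid stack;
# full-attention layers hold (past_key, past_value, index) while DeltaNet layers hold
# (recurrent_state, conv_state, step_index), so get_past_len reads the position from
# the first full-attention layer. With the default pattern that is layer 3:
#   _qwen35_layer_types(8).index("full_attention") == 3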
# Vision Encoder
class Qwen35VisionPatchEmbed(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.patch_size = config["patch_size"]
        self.temporal_patch_size = config["temporal_patch_size"]
        self.in_channels = config["in_channels"]
        self.embed_dim = config["hidden_size"]
        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        target_dtype = self.proj.weight.dtype
        x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)


class Qwen35VisionMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
        super().__init__()

        self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_state):
        return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))


class Qwen35VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim, theta=10000.0):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen):
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class Qwen35VisionAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
        super().__init__()

        self.dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.dim // self.num_heads
        self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)

    def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
        seq_length = x.shape[0]
        query_states, key_states, value_states = (
            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        query_states, key_states = apply_rope(query_states, key_states, position_embeddings)

        # Process per-sequence attention
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        q_splits = torch.split(query_states, lengths, dim=0)
        k_splits = torch.split(key_states, lengths, dim=0)
        v_splits = torch.split(value_states, lengths, dim=0)

        attn_outputs = []
        for q, k, v in zip(q_splits, k_splits, v_splits):
            q = q.transpose(0, 1).unsqueeze(0)
            k = k.transpose(0, 1).unsqueeze(0)
            v = v.transpose(0, 1).unsqueeze(0)
            attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))

        attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.reshape(seq_length, -1)
        return self.proj(attn_output)


class Qwen35VisionBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
        super().__init__()

        self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
        x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
        return x + self.mlp(self.norm2(x))


class Qwen35VisionPatchMerger(nn.Module):
    def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
        super().__init__()

        merge_dim = hidden_size * (spatial_merge_size ** 2)
        self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
        self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
        self.merge_dim = merge_dim

    def forward(self, x):
        x = self.norm(x).view(-1, self.merge_dim)
        return self.linear_fc2(F.gelu(self.linear_fc1(x)))


class Qwen35VisionModel(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.spatial_merge_size = config["spatial_merge_size"]
        self.patch_size = config["patch_size"]
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.hidden_size = config["hidden_size"]
|
||||
self.num_heads = config["num_heads"]
|
||||
self.num_position_embeddings = config["num_position_embeddings"]
|
||||
|
||||
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
|
||||
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
|
||||
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
|
||||
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config["depth"])
|
||||
])
|
||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def rot_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||
freq_table = self.rotary_pos_emb(max_hw)
|
||||
device = freq_table.device
|
||||
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||
offset = 0
|
||||
for num_frames, height, width in grid_thw_list:
|
||||
num_frames, height, width = int(num_frames), int(height), int(width)
|
||||
merged_h, merged_w = height // merge_size, width // merge_size
|
||||
block_rows = torch.arange(merged_h, device=device)
|
||||
block_cols = torch.arange(merged_w, device=device)
|
||||
intra_row = torch.arange(merge_size, device=device)
|
||||
intra_col = torch.arange(merge_size, device=device)
|
||||
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
||||
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
||||
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||
if num_frames > 1:
|
||||
coords = coords.repeat(num_frames, 1)
|
||||
num_tokens = coords.shape[0]
|
||||
pos_ids[offset:offset + num_tokens] = coords
|
||||
offset += num_tokens
|
||||
embeddings = freq_table[pos_ids]
|
||||
embeddings = embeddings.flatten(1)
|
||||
return embeddings
|
||||
|
||||
def fast_pos_embed_interpolate(self, grid_thw):
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_ts = [int(row[0]) for row in grid_thw_list]
|
||||
grid_hs = [int(row[1]) for row in grid_thw_list]
|
||||
grid_ws = [int(row[2]) for row in grid_thw_list]
|
||||
device = self.pos_embed.weight.device
|
||||
idx_list = [[] for _ in range(4)]
|
||||
weight_list = [[] for _ in range(4)]
|
||||
for t, h, w in grid_thw_list:
|
||||
h, w = int(h), int(w)
|
||||
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
|
||||
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
|
||||
h_idxs_floor = h_idxs.int()
|
||||
w_idxs_floor = w_idxs.int()
|
||||
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
dh = h_idxs - h_idxs_floor
|
||||
dw = w_idxs - w_idxs_floor
|
||||
base_h = h_idxs_floor * self.num_grid_per_side
|
||||
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
|
||||
indices = [
|
||||
(base_h[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h[None].T + w_idxs_ceil[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
|
||||
]
|
||||
weights = [
|
||||
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
|
||||
((1 - dh)[None].T * dw[None]).flatten(),
|
||||
(dh[None].T * (1 - dw)[None]).flatten(),
|
||||
(dh[None].T * dw[None]).flatten(),
|
||||
]
|
||||
for j in range(4):
|
||||
idx_list[j].extend(indices[j].tolist())
|
||||
weight_list[j].extend(weights[j].tolist())
|
||||
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
|
||||
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
|
||||
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
||||
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
|
||||
patch_pos_embeds_permute = []
|
||||
merge_size = self.spatial_merge_size
|
||||
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
|
||||
pos_embed = pos_embed.repeat(t, 1)
|
||||
pos_embed = (
|
||||
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
|
||||
.permute(0, 1, 3, 2, 4, 5)
|
||||
.flatten(0, 4)
|
||||
)
|
||||
patch_pos_embeds_permute.append(pos_embed)
|
||||
return torch.cat(patch_pos_embeds_permute)
|
||||
|
||||
def forward(self, x, grid_thw):
|
||||
x = self.patch_embed(x)
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
||||
x = x + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
seq_len = x.shape[0]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos().unsqueeze(-2)
|
||||
sin = emb.sin().unsqueeze(-2)
|
||||
sin_half = sin.shape[-1] // 2
|
||||
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
merged = self.merger(x)
|
||||
return merged
|
||||
|
||||
# Model Wrapper
|
||||
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
model_type = "qwen35_2b"
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = _make_config(self.model_type, config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
|
||||
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
|
||||
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
|
||||
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for i in range(model_config.num_hidden_layers):
|
||||
if model_config.layer_types[i] == "linear_attention":
|
||||
recurrent_state = torch.zeros(
|
||||
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
|
||||
device=device, dtype=torch.float32
|
||||
)
|
||||
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
|
||||
conv_state = torch.zeros(
|
||||
[batch, conv_dim, model_config.conv_kernel_size - 1],
|
||||
device=device, dtype=execution_dtype
|
||||
)
|
||||
past_key_values.append((recurrent_state, conv_state, 0))
|
||||
else:
|
||||
past_key_values.append((
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
0
|
||||
))
|
||||
return past_key_values
|
||||
|
||||
# Tokenizer and Text Encoder Wrappers
|
||||
|
||||
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
|
||||
from transformers import Qwen2Tokenizer
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
|
||||
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
|
||||
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image]
|
||||
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
if not thinking:
|
||||
llama_text += "<think>\n</think>\n"
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 248056: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen35ClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
|
||||
class Qwen35_(Qwen35):
|
||||
pass
|
||||
Qwen35_.model_type = model_type
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
|
||||
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Qwen35TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
|
||||
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def tokenizer(model_type="qwen35_2b"):
|
||||
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||
return Qwen35ImageTokenizer_
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
|
||||
class Qwen35TEModel_(Qwen35TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||
return Qwen35TEModel_
|
||||
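A brief sketch of how the two factory helpers above might be wired up. This call site is hypothetical and not part of the diff; the float16 dtype is an arbitrary choice for illustration.

# Hypothetical wiring sketch (not part of this diff). tokenizer() and te()
# are the factories defined above; the dtype passed here is an assumption.
TokenizerCls = tokenizer(model_type="qwen35_2b")
TEModelCls = te(dtype_llama=torch.float16, model_type="qwen35_2b")
tok = TokenizerCls(embedding_directory=None)
tokens = tok.tokenize_with_weights("a photo of a cat", images=[])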
comfy/text_encoders/qwen35_tokenizer/merges.txt (new file, 247587 lines; diff suppressed: too large)
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json (new file, 305 lines; diff suppressed: lines too long)
comfy/text_encoders/qwen35_tokenizer/vocab.json (new file, 248046 lines; diff suppressed: too large)
@@ -5,6 +5,10 @@ from comfy_api.latest._input import (
    MaskInput,
    LatentInput,
    VideoInput,
    CurvePoint,
    CurveInput,
    MonotoneCubicCurve,
    LinearCurve,
)

__all__ = [
@@ -13,4 +17,8 @@ __all__ = [
    "MaskInput",
    "LatentInput",
    "VideoInput",
    "CurvePoint",
    "CurveInput",
    "MonotoneCubicCurve",
    "LinearCurve",
]

@@ -1,4 +1,5 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .video_types import VideoInput

__all__ = [
@@ -7,4 +8,8 @@ __all__ = [
    "VideoInput",
    "MaskInput",
    "LatentInput",
    "CurvePoint",
    "CurveInput",
    "MonotoneCubicCurve",
    "LinearCurve",
]
comfy_api/latest/_input/curve_types.py (new file, 219 lines)
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> CurveInput:
|
||||
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||
|
||||
Accepts:
|
||||
- A ``CurveInput`` instance (returned as-is).
|
||||
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||
"""
|
||||
if isinstance(data, CurveInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
raw_points = data["points"]
|
||||
interpolation = data.get("interpolation", "monotone_cubic")
|
||||
else:
|
||||
raw_points = data
|
||||
interpolation = "monotone_cubic"
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
if interpolation == "linear":
|
||||
return LinearCurve(points)
|
||||
if interpolation != "monotone_cubic":
|
||||
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||
return MonotoneCubicCurve(points)
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
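A quick sketch of how the classes above compose; the control points are arbitrary values chosen for illustration.

# Illustrative usage of the curve classes above (control points are made up).
curve = CurveInput.from_raw({"points": [(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)],
                             "interpolation": "monotone_cubic"})
y = curve.interp(0.25)          # scalar evaluation, clamped to the endpoint range
lut = curve.to_lut(size=256)    # 256-sample float64 LUT over x in [0, 1]
assert lut[0] == 0.0 and lut[-1] == 1.0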
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
    from comfy.samplers import CFGGuider, Sampler
    from comfy.sd import CLIP, VAE
    from comfy.sd import StyleModel as StyleModel_
    from comfy_api.input import VideoInput
    from comfy_api.input import VideoInput, CurveInput as CurveInput_
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
                                prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
@@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):

@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
    CurvePoint = tuple[float, float]
    Type = list[CurvePoint]
    from comfy_api.input import CurvePoint
    if TYPE_CHECKING:
        Type = CurveInput_

    class Input(WidgetInput):
        def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
            if default is None:
                self.default = [(0.0, 0.0), (1.0, 1.0)]

        def as_dict(self):
            d = super().as_dict()
            if self.default is not None:
                d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
            return d


@comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO):
    """A histogram represented as a list of bin counts."""
    Type = list[int]


DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
@@ -1360,6 +1373,7 @@ class NodeInfoV1:
    price_badge: dict | None = None
    search_aliases: list[str]=None
    essentials_category: str=None
    has_intermediate_output: bool=None


@dataclass
@@ -1483,6 +1497,16 @@ class Schema:
    """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
    essentials_category: str | None = None
    """Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
    has_intermediate_output: bool=False
    """Flags this node as having intermediate output that should persist across page refreshes.

    Nodes with this flag behave like output nodes (their UI results are cached and resent
    to the frontend) but do NOT automatically get added to the execution list. This means
    they will only execute if they are on the dependency path of a real output node.

    Use this for nodes with interactive/operable UI regions that produce intermediate outputs
    (e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
    """

    def validate(self):
        '''Validate the schema:
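For reference, a node opts into this behaviour by setting the flag on its schema. A minimal sketch follows; the field values are taken from the ImageCropV2 hunk later in this diff, but the snippet itself is illustrative, not part of the change.

# Illustrative schema using the new flag; only has_intermediate_output is new,
# the other fields are standard Schema fields.
schema = Schema(
    node_id="ImageCropV2",
    display_name="Image Crop",
    category="image/transform",
    has_intermediate_output=True,  # UI results cached/resent, but not force-executed
)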
@@ -1582,6 +1606,7 @@ class Schema:
            category=self.category,
            description=self.description,
            output_node=self.is_output_node,
            has_intermediate_output=self.has_intermediate_output,
            deprecated=self.is_deprecated,
            experimental=self.is_experimental,
            dev_only=self.is_dev_only,
@@ -1873,6 +1898,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
            cls.GET_SCHEMA()
        return cls._OUTPUT_NODE

    _HAS_INTERMEDIATE_OUTPUT = None
    @final
    @classproperty
    def HAS_INTERMEDIATE_OUTPUT(cls):  # noqa
        if cls._HAS_INTERMEDIATE_OUTPUT is None:
            cls.GET_SCHEMA()
        return cls._HAS_INTERMEDIATE_OUTPUT

    _INPUT_IS_LIST = None
    @final
    @classproperty
@@ -1965,6 +1998,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
            cls._API_NODE = schema.is_api_node
        if cls._OUTPUT_NODE is None:
            cls._OUTPUT_NODE = schema.is_output_node
        if cls._HAS_INTERMEDIATE_OUTPUT is None:
            cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
        if cls._INPUT_IS_LIST is None:
            cls._INPUT_IS_LIST = schema.is_input_list
        if cls._NOT_IDEMPOTENT is None:
@@ -2240,5 +2275,6 @@ __all__ = [
    "PriceBadge",
    "BoundingBox",
    "Curve",
    "Histogram",
    "NodeReplace",
]
@@ -201,6 +201,16 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
            returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
            image_tensors.append(returned_image)
    if len(image_tensors) == 0:
        if not thought:
            # No images generated --> extract text response for a meaningful error
            model_message = get_text_from_response(response).strip()
            if model_message:
                raise ValueError(f"Gemini did not generate an image. Model response: {model_message}")
            raise ValueError(
                "Gemini did not generate an image. "
                "Try rephrasing your prompt or changing the response modality to 'IMAGE+TEXT' "
                "to see the model's reasoning."
            )
        return torch.zeros((1, 1024, 1024, 4))
    return torch.cat(image_tensors, dim=0)
@@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
                depends_on=IO.PriceBadgeDepends(
                    widgets=["upscale", "upscale.upscale_factor"],
                ),
                expr="""
                (
                    $factor := $lookup(widgets, "upscale.upscale_factor");
                    $fmt := {"approximate": true, "note": "(base)"};
                    widgets.upscale = "enabled" ? (
                        $factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
                        : $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
                        : {"type": "usd", "usd": 0.0457, "format": $fmt}
                    ) : {"type": "usd", "usd": 0.03432, "format": $fmt}
                )
                """,
            ),
        )
@@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(
                    widgets=["model"],
                    widgets=["model", "upscale", "upscale.upscale_factor"],
                ),
                expr="""
                (
                    $fmt := {"approximate": true, "note": "(base)"};
                    $isFast := $contains(widgets.model, "fast");
                    $base := $isFast ? 0.01001 : 0.0572;
                    {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
                    $enabled := widgets.upscale = "enabled";
                    $factor := $lookup(widgets, "upscale.upscale_factor");
                    $isFast
                    ? {"type": "usd", "usd": 0.01001, "format": $fmt}
                    : $enabled ? (
                        $factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
                        : $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
                        : {"type": "usd", "usd": 0.0686, "format": $fmt}
                    ) : {"type": "usd", "usd": 0.0572, "format": $fmt}
                )
                """,
            ),
@@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(
                    widgets=["model"],
                    widgets=["model", "upscale", "upscale.upscale_factor"],
                ),
                expr="""
                (
                    $fmt := {"approximate": true, "note": "(base)"};
                    $isFast := $contains(widgets.model, "fast");
                    $base := $isFast ? 0.01001 : 0.0572;
                    {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
                    $enabled := widgets.upscale = "enabled";
                    $factor := $lookup(widgets, "upscale.upscale_factor");
                    $isFast
                    ? {"type": "usd", "usd": 0.01001, "format": $fmt}
                    : $enabled ? (
                        $factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
                        : $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
                        : {"type": "usd", "usd": 0.0686, "format": $fmt}
                    ) : {"type": "usd", "usd": 0.0572, "format": $fmt}
                )
                """,
            ),
@@ -38,6 +38,7 @@ from comfy_api_nodes.util import (
UPSCALER_MODELS_MAP = {
    "Starlight (Astra) Fast": "slf-1",
    "Starlight (Astra) Creative": "slc-1",
    "Starlight Precise 2.5": "slp-2.5",
}
@@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@@ -475,6 +474,10 @@ class LRUCache(BasicCache):
        self._mark_used(node_id)
        return await self._set_immediate(node_id, value)

    def set_local(self, node_id, value):
        self._mark_used(node_id)
        BasicCache.set_local(self, node_id, value)

    async def ensure_subcache_for(self, node_id, children_ids):
        # Just uses subcaches for tracking 'live' nodes
        await super()._ensure_subcache(node_id, children_ids)
@@ -489,15 +492,10 @@
        return self


#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
#Small baseline weight used when a cache entry has no measurable CPU tensors.
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.

RAM_CACHE_HYSTERESIS = 1.1

#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK

RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05

#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
@@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
        return await super().get(node_id)

    def poll(self, ram_headroom):
        def _ram_gb():
            return psutil.virtual_memory().available / (1024**3)
    def set_local(self, node_id, value):
        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
        super().set_local(node_id, value)

        if _ram_gb() > ram_headroom:
            return
        gc.collect()
        if _ram_gb() > ram_headroom:
    def ram_release(self, target):
        if psutil.virtual_memory().available >= target:
            return

        clean_list = []

        for key, (outputs, _), in self.cache.items():
        for key, cache_entry in self.cache.items():
            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])

            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -542,22 +538,20 @@
                if outputs is None:
                    return
                for output in outputs:
                    if isinstance(output, list):
                    if isinstance(output, (list, tuple)):
                        scan_list_for_ram_usage(output)
                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
                        #score Tensors at a 50% discount for RAM usage as they are likely to
                        #be high value intermediates
                        ram_usage += (output.numel() * output.element_size()) * 0.5
                    elif hasattr(output, "get_ram_usage"):
                        ram_usage += output.get_ram_usage()
            scan_list_for_ram_usage(outputs)
                        ram_usage += output.numel() * output.element_size()
            scan_list_for_ram_usage(cache_entry.outputs)

            oom_score *= ram_usage
            #In the case where we have no information on the node ram usage at all,
            #break OOM score ties on the last touch timestamp (pure LRU)
            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))

        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
        while psutil.virtual_memory().available < target and clean_list:
            _, _, key = clean_list.pop()
            del self.cache[key]
            gc.collect()
            self.used_generation.pop(key, None)
            self.timestamps.pop(key, None)
            self.children.pop(key, None)
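The eviction heuristic above reduces to "score = age-bias x estimated RAM, evict highest first, fall back to last-touch time on ties". A standalone sketch of that ordering follows; the constants RAM_CACHE_DEFAULT_RAM_USAGE is copied from the hunk, while the multiplier value and the entry tuple shape are assumptions for illustration.

# Standalone sketch of the eviction ordering used by ram_release() above.
import bisect

RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 2  # assumed value for this sketch

def eviction_order(entries, generation):
    # entries: list of (key, ram_gb_estimate, used_generation, last_touch);
    # this shape is assumed, the real cache stores richer entries.
    scored = []
    for key, ram_gb, used_gen, touched in entries:
        oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (generation - used_gen)
        oom_score *= max(ram_gb, RAM_CACHE_DEFAULT_RAM_USAGE)
        # score ties fall back to the last-touch timestamp (pure LRU)
        bisect.insort(scored, (oom_score, touched, key))
    return [key for _, _, key in reversed(scored)]  # highest score evicted first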
comfy_extras/nodes_curve.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from __future__ import annotations

from comfy_api.latest import ComfyExtension, io
from comfy_api.input import CurveInput
from typing_extensions import override


class CurveEditor(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CurveEditor",
            display_name="Curve Editor",
            category="utils",
            inputs=[
                io.Curve.Input("curve"),
                io.Histogram.Input("histogram", optional=True),
            ],
            outputs=[
                io.Curve.Output("curve"),
            ],
        )

    @classmethod
    def execute(cls, curve, histogram=None) -> io.NodeOutput:
        result = CurveInput.from_raw(curve)

        ui = {}
        if histogram is not None:
            ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)

        return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)


class CurveExtension(ComfyExtension):
    @override
    async def get_node_list(self):
        return [CurveEditor]


async def comfy_entrypoint():
    return CurveExtension()
@@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):


MAX_IMAGES = 5  # u_image0-4
MAX_UNIFORMS = 5  # u_float0-4, u_int0-4
MAX_UNIFORMS = 20  # u_float0-19, u_int0-19
MAX_BOOLS = 10  # u_bool0-9
MAX_CURVES = 4  # u_curve0-3 (1D LUT textures)
MAX_OUTPUTS = 4  # fragColor0-3 (MRT)

# Vertex shader using gl_VertexID trick - no VBO needed.
@@ -497,6 +499,8 @@ def _render_shader_batch(
    image_batches: list[list[np.ndarray]],
    floats: list[float],
    ints: list[int],
    bools: list[bool] | None = None,
    curves: list[np.ndarray] | None = None,
) -> list[list[np.ndarray]]:
    """
    Render a fragment shader for multiple batches efficiently.
@@ -511,6 +515,8 @@ def _render_shader_batch(
        image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
        floats: List of float uniforms
        ints: List of int uniforms
        bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
        curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N

    Returns:
        List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
@@ -533,11 +539,17 @@ def _render_shader_batch(
    # Detect multi-pass rendering
    num_passes = _detect_pass_count(fragment_code)

    if bools is None:
        bools = []
    if curves is None:
        curves = []

    # Track resources for cleanup
    program = None
    fbo = None
    output_textures = []
    input_textures = []
    curve_textures = []
    ping_pong_textures = []
    ping_pong_fbos = []

@@ -624,6 +636,28 @@ def _render_shader_batch(
        if loc >= 0:
            gl.glUniform1i(loc, v)

    for i, v in enumerate(bools):
        loc = gl.glGetUniformLocation(program, f"u_bool{i}")
        if loc >= 0:
            gl.glUniform1i(loc, 1 if v else 0)

    # Create 1D LUT textures for curves (bound after image texture units)
    for i, lut in enumerate(curves):
        tex = gl.glGenTextures(1)
        curve_textures.append(tex)
        unit = MAX_IMAGES + i
        gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
        gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
        gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
        gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
        gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
        gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
        gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)

        loc = gl.glGetUniformLocation(program, f"u_curve{i}")
        if loc >= 0:
            gl.glUniform1i(loc, unit)

    # Get u_pass uniform location for multi-pass
    pass_loc = gl.glGetUniformLocation(program, "u_pass")
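Since each LUT is uploaded as a 1-texel-high GL_R32F texture with linear filtering, sampling `texture(u_curve0, vec2(x, 0.5)).r` amounts to linear interpolation into the array. A numpy sketch of that equivalence, assuming the clamp-to-edge and linear-filter parameters set above:

# Numpy sketch of what texture(u_curve0, vec2(x, 0.5)).r computes for the
# LUT uploaded above (GL linear filtering + clamp-to-edge assumed).
import numpy as np

def sample_curve_lut(lut: np.ndarray, x: float) -> float:
    n = len(lut)
    # GL maps texcoord x to texel space as x * n - 0.5, then lerps neighbours
    pos = float(np.clip(x * n - 0.5, 0.0, n - 1.0))
    lo = int(np.floor(pos))
    hi = min(lo + 1, n - 1)
    frac = pos - lo
    return float((1.0 - frac) * lut[lo] + frac * lut[hi])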
@@ -718,6 +752,8 @@ def _render_shader_batch(

        for tex in input_textures:
            gl.glDeleteTextures(int(tex))
        for tex in curve_textures:
            gl.glDeleteTextures(int(tex))
        for tex in output_textures:
            gl.glDeleteTextures(int(tex))
        for tex in ping_pong_textures:
@@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
            max=MAX_UNIFORMS,
        )

        bool_template = io.Autogrow.TemplatePrefix(
            io.Boolean.Input("bool", default=False),
            prefix="u_bool",
            min=0,
            max=MAX_BOOLS,
        )

        curve_template = io.Autogrow.TemplatePrefix(
            io.Curve.Input("curve"),
            prefix="u_curve",
            min=0,
            max=MAX_CURVES,
        )

        return io.Schema(
            node_id="GLSLShader",
            display_name="GLSL Shader",
@@ -762,6 +812,8 @@ class GLSLShader(io.ComfyNode):
                "Apply GLSL ES fragment shaders to images. "
                "u_resolution (vec2) is always available."
            ),
            is_experimental=True,
            has_intermediate_output=True,
            inputs=[
                io.String.Input(
                    "fragment_shader",
@@ -796,6 +848,8 @@ class GLSLShader(io.ComfyNode):
                io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
                io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
                io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
                io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
                io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
            ],
            outputs=[
                io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
@@ -813,13 +867,19 @@ class GLSLShader(io.ComfyNode):
        images: io.Autogrow.Type,
        floats: io.Autogrow.Type = None,
        ints: io.Autogrow.Type = None,
        bools: io.Autogrow.Type = None,
        curves: io.Autogrow.Type = None,
        **kwargs,
    ) -> io.NodeOutput:

        image_list = [v for v in images.values() if v is not None]
        float_list = (
            [v if v is not None else 0.0 for v in floats.values()] if floats else []
        )
        int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
        bool_list = [v if v is not None else False for v in bools.values()] if bools else []

        curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []

        if not image_list:
            raise ValueError("At least one input image is required")
@@ -846,6 +906,8 @@ class GLSLShader(io.ComfyNode):
            image_batches,
            float_list,
            int_list,
            bool_list,
            curve_luts,
        )

        # Collect outputs into tensors
@@ -59,6 +59,7 @@ class ImageCropV2(IO.ComfyNode):
            display_name="Image Crop",
            category="image/transform",
            essentials_category="Image Tools",
            has_intermediate_output=True,
            inputs=[
                IO.Image.Input("image"),
                IO.BoundingBox.Input("crop_region", component="ImageCrop"),
comfy_extras/nodes_number_convert.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""Number Convert node for unified numeric type conversion.

Provides a single node that converts INT, FLOAT, STRING, and BOOL
inputs into FLOAT and INT outputs.
"""

from __future__ import annotations

import math

from typing_extensions import override

from comfy_api.latest import ComfyExtension, io


class NumberConvertNode(io.ComfyNode):
    """Converts various types to numeric FLOAT and INT outputs."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="ComfyNumberConvert",
            display_name="Number Convert",
            category="math",
            search_aliases=[
                "int to float", "float to int", "number convert",
                "int2float", "float2int", "cast", "parse number",
                "string to number", "bool to int",
            ],
            inputs=[
                io.MultiType.Input(
                    "value",
                    [io.Int, io.Float, io.String, io.Boolean],
                    display_name="value",
                ),
            ],
            outputs=[
                io.Float.Output(display_name="FLOAT"),
                io.Int.Output(display_name="INT"),
            ],
        )

    @classmethod
    def execute(cls, value) -> io.NodeOutput:
        if isinstance(value, bool):
            float_val = 1.0 if value else 0.0
            int_val = 1 if value else 0
        elif isinstance(value, int):
            float_val = float(value)
            int_val = value
        elif isinstance(value, float):
            float_val = value
            int_val = int(value)
        elif isinstance(value, str):
            text = value.strip()
            if not text:
                raise ValueError("Cannot convert empty string to number.")
            try:
                float_val = float(text)
            except ValueError:
                raise ValueError(
                    f"Cannot convert string to number: {value!r}"
                ) from None
            if not math.isfinite(float_val):
                raise ValueError(
                    f"Cannot convert non-finite value to number: {float_val}"
                )
            try:
                int_val = int(text)
            except ValueError:
                int_val = int(float_val)
        else:
            raise TypeError(
                f"Unsupported input type: {type(value).__name__}"
            )

        if not math.isfinite(float_val):
            raise ValueError(
                f"Cannot convert non-finite value to number: {float_val}"
            )

        return io.NodeOutput(float_val, int_val)


class NumberConvertExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [NumberConvertNode]


async def comfy_entrypoint() -> NumberConvertExtension:
    return NumberConvertExtension()
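The string branch parses via float first, so decimal strings still yield an INT. A few illustrative expectations inferred from the branches above (not an official test):

# Illustrative input -> (FLOAT, INT) pairs for NumberConvertNode.execute:
#   True      -> (1.0, 1)
#   7         -> (7.0, 7)
#   2.9       -> (2.9, 2)        # int() truncates toward zero
#   " 3.5 "   -> (3.5, 3)        # stripped, parsed as float, int() fallback
#   "nan"     -> ValueError      # non-finite floats are rejected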
@@ -30,6 +30,7 @@ class PainterNode(io.ComfyNode):
            node_id="Painter",
            display_name="Painter",
            category="image",
            has_intermediate_output=True,
            inputs=[
                io.Image.Input(
                    "image",
@@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
    def g(cls, x):
        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

def gaussian_kernel(kernel_size: int, sigma: float, device=None):
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()
    return (g / g.sum()).to(dtype)

class Blur(io.ComfyNode):
    @classmethod
@@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
        batch_size, height, width, channels = image.shape

        kernel_size = blur_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)

        image = image.permute(0, 3, 1, 2)  # Torch wants (B, C, H, W) we use (B, H, W, C)
        padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
@@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
        image = image.to(comfy.model_management.get_torch_device())

        kernel_size = sharpen_radius * 2 + 1
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
        kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
        kernel = kernel.to(dtype=image.dtype)
        center = kernel_size // 2
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
154
comfy_extras/nodes_rtdetr.py
Normal file
154
comfy_extras/nodes_rtdetr.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
|
||||
class RTDETR_detect(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RTDETR_detect",
|
||||
display_name="RT-DETR Detect",
|
||||
category="detection/",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Image.Input("image", display_name="image"),
|
||||
io.Float.Input("threshold", display_name="threshold", default=0.5),
|
||||
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
|
||||
io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
|
||||
],
|
||||
outputs=[
|
||||
io.BoundingBox.Output("bboxes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
||||
|
||||
all_bbox_dicts = []
|
||||
|
||||
for det in results:
|
||||
keep = det['scores'] > threshold
|
||||
boxes = det['boxes'][keep].cpu()
|
||||
labels = det['labels'][keep].cpu()
|
||||
scores = det['scores'][keep].cpu()
|
||||
|
||||
bbox_dicts = [
|
||||
{
|
||||
"x": float(box[0]),
|
||||
"y": float(box[1]),
|
||||
"width": float(box[2] - box[0]),
|
||||
"height": float(box[3] - box[1]),
|
||||
"label": COCO_CLASSES[int(label)],
|
||||
"score": float(score)
|
||||
}
|
||||
for box, label, score in zip(boxes, labels, scores)
|
||||
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
|
||||
]
|
||||
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
|
||||
all_bbox_dicts.append(bbox_dicts[:max_detections])
|
||||
|
||||
return io.NodeOutput(all_bbox_dicts)
|
||||
|
||||
|
||||
class DrawBBoxes(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DrawBBoxes",
            display_name="Draw BBoxes",
            category="detection/",
            search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
            inputs=[
                io.Image.Input("image", optional=True),
                io.BoundingBox.Input("bboxes", force_input=True),
            ],
            outputs=[
                io.Image.Output("out_image"),
            ],
        )

    @classmethod
    def execute(cls, bboxes, image=None) -> io.NodeOutput:
        # Normalise to list[list[dict]], then fit to batch size B.
        B = image.shape[0] if image is not None else 1
        if isinstance(bboxes, dict):
            bboxes = [[bboxes]]
        elif not isinstance(bboxes, list) or not bboxes:
            bboxes = [[]]
        elif isinstance(bboxes[0], dict):
            bboxes = [bboxes]  # flat list → same detections for every image

        if len(bboxes) == 1:
            bboxes = bboxes * B
        bboxes = (bboxes + [[]] * B)[:B]

        if image is None:
            B = len(bboxes)
            max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
            max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
            image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32)

        all_out_images = []
        for i in range(B):
            detections = bboxes[i]
            if detections:
                boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
                labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
                scores = torch.tensor([d.get("score", 1.0) for d in detections])
            else:
                boxes = torch.zeros((0, 4))
                labels = []
                scores = torch.zeros((0,))

            pil_image = image[i].movedim(-1, 0)
            img = ToPILImage()(pil_image)
            if detections:
                img = cls.draw_detections(img, boxes, labels, scores)
            all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1))

        out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
        return io.NodeOutput(out_images)

    @classmethod
    def draw_detections(cls, img, boxes, labels, scores):
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype('arial.ttf', 16)
        except Exception:
            font = ImageFont.load_default()
        colors = [(255, 0, 0), (0, 200, 0), (0, 0, 255), (255, 165, 0), (128, 0, 128),
                  (0, 255, 255), (255, 20, 147), (100, 149, 237)]
        for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
            x1, y1, x2, y2 = box.tolist()
            color_idx = COCO_CLASSES.index(label) if label is not None else 0
            c = colors[color_idx % len(colors)]
            draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
            if label is not None:
                draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
        return img


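For reference, the input shapes the normalisation above accepts, and what each becomes for a batch of B=2 (values illustrative, not from the diff):

# d = {"x": 0.0, "y": 0.0, "width": 10.0, "height": 10.0}
# single dict:            d              -> [[d], [d]]        (broadcast to every image)
# flat list of dicts:     [d, d]         -> [[d, d], [d, d]]  (same boxes on each image)
# nested per-image list:  [[d], [d, d]]  -> [[d], [d, d]]     (padded/truncated to B)
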
class RTDETRExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            RTDETR_detect,
            DrawBBoxes,
        ]


async def comfy_entrypoint() -> RTDETRExtension:
    return RTDETRExtension()

@@ -661,6 +661,7 @@ class CropByBBoxes(io.ComfyNode):
                io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
                io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
                io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
                io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."),
            ],
            outputs=[
                io.Image.Output(tooltip="All crops stacked into a single image batch."),
@@ -668,7 +669,7 @@ class CropByBBoxes(io.ComfyNode):
        )

    @classmethod
-    def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
+    def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput:
        total_frames = image.shape[0]
        img_h = image.shape[1]
        img_w = image.shape[2]
@@ -716,7 +717,19 @@ class CropByBBoxes(io.ComfyNode):
                x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2

            crop_chw = frame_chw[:, :, y1:y2, x1:x2]  # (1, C, crop_h, crop_w)
-            resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
+
            if keep_aspect == "pad":
                crop_h, crop_w = y2 - y1, x2 - x1
                scale = min(output_width / crop_w, output_height / crop_h)
                scaled_w = int(round(crop_w * scale))
                scaled_h = int(round(crop_h * scale))
                scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
                pad_left = (output_width - scaled_w) // 2
                pad_top = (output_height - scaled_h) // 2
                resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
                resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
            else:  # "stretch"
                resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
            crops.append(resized)

        if not crops:

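A quick worked example of the "pad" branch above: a 300x100 crop letterboxed into a 512x512 output (numbers illustrative):

# scale    = min(512 / 300, 512 / 100) ~= 1.7067   (limited by width)
# scaled_w = round(300 * 1.7067) = 512
# scaled_h = round(100 * 1.7067) = 171
# pad_left = (512 - 512) // 2 = 0
# pad_top  = (512 - 171) // 2 = 170
# -> the crop fills rows 170..340; the remaining rows stay black.
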
@@ -15,6 +15,7 @@ class TextGenerate(io.ComfyNode):
                        io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
                        io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
                        io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
                        io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
                    ]
                ),
                io.DynamicCombo.Option(
@@ -25,7 +26,7 @@ class TextGenerate(io.ComfyNode):

        return io.Schema(
            node_id="TextGenerate",
-            category="textgen/",
+            category="textgen",
            search_aliases=["LLM", "gemma"],
            inputs=[
                io.Clip.Input("clip"),
@@ -33,6 +34,7 @@ class TextGenerate(io.ComfyNode):
                io.Image.Input("image", optional=True),
                io.Int.Input("max_length", default=256, min=1, max=2048),
                io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
                io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
            ],
            outputs=[
                io.String.Output(display_name="generated_text"),
@@ -40,9 +42,9 @@ class TextGenerate(io.ComfyNode):
        )

    @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:

-        tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
+        tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)

        # Get sampling parameters from dynamic combo
        do_sample = sampling_mode.get("sampling_mode") == "on"
@@ -52,6 +54,7 @@ class TextGenerate(io.ComfyNode):
        min_p = sampling_mode.get("min_p", 0.0)
        seed = sampling_mode.get("seed", None)
        repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
        presence_penalty = sampling_mode.get("presence_penalty", 0.0)

        generated_ids = clip.generate(
            tokens,
@@ -62,6 +65,7 @@ class TextGenerate(io.ComfyNode):
            top_p=top_p,
            min_p=min_p,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            seed=seed
        )

@@ -156,12 +160,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
        )

    @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
        if image is None:
            formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
        else:
            formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)


class TextgenExtension(ComfyExtension):

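The DynamicCombo hands execute() its sub-inputs as a plain dict keyed by input name. A hypothetical payload matching the .get() calls in the hunks above (any keys beyond those reads are an assumption):

example_sampling_mode = {
    "sampling_mode": "on",        # enables do_sample
    "top_p": 0.9,
    "min_p": 0.05,
    "repetition_penalty": 1.05,
    "presence_penalty": 0.0,      # the newly added optional input
    "seed": 0,
}
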
@@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
                    default="bf16",
                    tooltip="The dtype to use for lora.",
                ),
                io.Boolean.Input(
                    "quantized_backward",
                    default=False,
                    tooltip="When training_dtype is 'none' and the model is quantized, run the backward pass with quantized matmul.",
                ),
                io.Combo.Input(
                    "algorithm",
                    options=list(adapter_maps.keys()),
@@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
        seed,
        training_dtype,
        lora_dtype,
        quantized_backward,
        algorithm,
        gradient_checkpointing,
        checkpoint_depth,
@@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
        seed = seed[0]
        training_dtype = training_dtype[0]
        lora_dtype = lora_dtype[0]
        quantized_backward = quantized_backward[0]
        algorithm = algorithm[0]
        gradient_checkpointing = gradient_checkpointing[0]
        offloading = offloading[0]
@@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
        bucket_mode = bucket_mode[0]
        bypass_mode = bypass_mode[0]

        comfy.model_management.training_fp8_bwd = quantized_backward

        # Process latents based on mode
        if bucket_mode:
            latents = _process_latents_bucket_mode(latents)
@@ -1137,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
        # Setup model and dtype
        mp = model.clone()
        use_grad_scaler = False
        lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
        if training_dtype != "none":
            dtype = node_helpers.string_to_torch_dtype(training_dtype)
            mp.set_model_compute_dtype(dtype)
@@ -1145,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
            model_dtype = mp.model.get_dtype()
            if model_dtype == torch.float16:
                dtype = torch.float16
-                use_grad_scaler = True
                # GradScaler only supports float16 gradients, not bfloat16.
                # Only enable it when lora params will also be in float16.
                if lora_dtype != torch.bfloat16:
                    use_grad_scaler = True
                # Warn about fp16 accumulation instability during training
                if PerformanceFeature.Fp16Accumulation in args.fast:
                    logging.warning(
@@ -1156,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
            else:
                # For fp8, bf16, or other dtypes, use bf16 autocast
                dtype = torch.bfloat16
-                lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)

        # Prepare latents and compute counts
        latents_dtype = dtype if dtype not in (None,) else torch.bfloat16

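Background for the reworked guard: loss scaling exists to counter fp16 gradient underflow, and per the comment in the hunk the scaler only supports float16 gradients, not bfloat16. A minimal standalone sketch of the decision, not ComfyUI code:

import torch

def wants_grad_scaler(model_dtype: torch.dtype, lora_dtype: torch.dtype) -> bool:
    # fp16 compute underflows easily, so loss scaling helps -- but the scaler
    # can only unscale float16 gradients, so skip it when the trainable
    # LoRA params are bfloat16.
    return model_dtype == torch.float16 and lora_dtype != torch.bfloat16

assert wants_grad_scaler(torch.float16, torch.float16)
assert not wants_grad_scaler(torch.float16, torch.bfloat16)
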
execution.py (38 changes)
@@ -411,6 +411,19 @@ def format_value(x):
    else:
        return str(x)

def _is_intermediate_output(dynprompt, node_id):
    class_type = dynprompt.get_node(node_id)["class_type"]
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)

def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
    if server.client_id is None:
        return
    cached_ui = cached.ui or {}
    server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
    if cached.ui is not None:
        ui_outputs[node_id] = cached.ui

async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
    unique_id = current_item
    real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -421,11 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    cached = await caches.outputs.get(unique_id)
    if cached is not None:
-        if server.client_id is not None:
-            cached_ui = cached.ui or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
-            if cached.ui is not None:
-                ui_outputs[unique_id] = cached.ui
+        _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
        get_progress_state().finish_progress(unique_id)
        execution_list.cache_update(unique_id, cached)
        return (ExecutionResult.SUCCESS, None, None)
@@ -715,6 +724,9 @@ class PromptExecutor:
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

        self._notify_prompt_lifecycle("start", prompt_id)
        ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
        ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
        comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)

        try:
            with torch.inference_mode():
@@ -764,9 +776,22 @@ class PromptExecutor:
                        execution_list.unstage_node_execution()
                    else: # result == ExecutionResult.SUCCESS:
                        execution_list.complete_node_execution()
                        self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])

                        if self.cache_type == CacheType.RAM_PRESSURE:
                            comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
                            comfy.memory_management.extra_ram_release(ram_headroom)
                else:
                    # Only execute when the while-loop ends without break
                    # Send cached UI for intermediate output nodes that weren't executed
                    for node_id in dynamic_prompt.all_node_ids():
                        if node_id in executed:
                            continue
                        if not _is_intermediate_output(dynamic_prompt, node_id):
                            continue
                        cached = await self.caches.outputs.get(node_id)
                        if cached is not None:
                            display_node_id = dynamic_prompt.get_display_node_id(node_id)
                            _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
                    self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

            ui_outputs = {}
@@ -782,6 +807,7 @@ class PromptExecutor:
            if comfy.model_management.DISABLE_SMART_MEMORY:
                comfy.model_management.unload_all_models()
        finally:
            comfy.memory_management.set_ram_cache_release_state(None, 0)
            self._notify_prompt_lifecycle("end", prompt_id)

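A node opts into this late cached-UI flush by setting the class attribute that _is_intermediate_output probes. A hypothetical node class; only the attribute name comes from the hunk:

class MyPreviewNode:
    # Read via getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False);
    # classes without the attribute default to False and are skipped.
    HAS_INTERMEDIATE_OUTPUT = True
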
main.py (51 changes)
@@ -9,6 +9,8 @@ import folder_paths
import time
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files
import itertools
import utils.extra_config
from utils.mime_types import init_mime_types
@@ -192,7 +194,6 @@ if 'torch' in sys.modules:


import comfy.utils
-from app.assets.seeder import asset_seeder

import execution
import server
@@ -240,17 +241,53 @@ def cuda_malloc_warning():
        logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")


def _collect_output_absolute_paths(history_result: dict) -> list[str]:
    """Extract absolute file paths for output items from a history result."""
    paths: list[str] = []
    seen: set[str] = set()
    for node_output in history_result.get("outputs", {}).values():
        for items in node_output.values():
            if not isinstance(items, list):
                continue
            for item in items:
                if not isinstance(item, dict):
                    continue
                item_type = item.get("type")
                if item_type not in ("output", "temp"):
                    continue
                base_dir = folder_paths.get_directory_by_type(item_type)
                if base_dir is None:
                    continue
                base_dir = os.path.abspath(base_dir)
                filename = item.get("filename")
                if not filename:
                    continue
                abs_path = os.path.abspath(
                    os.path.join(base_dir, item.get("subfolder", ""), filename)
                )
                if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
                    continue
                if abs_path not in seen:
                    seen.add(abs_path)
                    paths.append(abs_path)
    return paths

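A sketch of the history_result shape the function above expects, inferred from the keys it reads (values illustrative):

example_history_result = {
    "outputs": {
        "9": {  # node id -> {output name -> list of saved items}
            "images": [
                {"filename": "ComfyUI_00001_.png", "subfolder": "", "type": "output"},
                {"filename": "preview.png", "subfolder": "tmp", "type": "temp"},
            ],
        },
    },
}
# Each item resolves to <base_dir>/<subfolder>/<filename>; anything that
# escapes its base directory after os.path.abspath() is silently dropped.
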
def prompt_worker(q, server_instance):
    current_time: float = 0.0
    cache_ram = args.cache_ram
    if cache_ram < 0:
        cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))

    cache_type = execution.CacheType.CLASSIC
    if args.cache_lru > 0:
        cache_type = execution.CacheType.LRU
-    elif args.cache_ram > 0:
+    elif cache_ram > 0:
        cache_type = execution.CacheType.RAM_PRESSURE
    elif args.cache_none:
        cache_type = execution.CacheType.NONE

-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
    last_gc_collect = 0
    need_gc = False
    gc_collect_interval = 10.0
@@ -274,6 +311,7 @@ def prompt_worker(q, server_instance):

            asset_seeder.pause()
            e.execute(item[2], prompt_id, extra_data, item[4])

            need_gc = True

            remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
@@ -296,6 +334,10 @@ def prompt_worker(q, server_instance):
            else:
                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))

            if not asset_seeder.is_disabled():
                paths = _collect_output_absolute_paths(e.history_result)
                register_output_files(paths, job_id=prompt_id)

        flags = q.get_flags()
        free_memory = flags.get("free_memory", False)

@@ -317,6 +359,9 @@ def prompt_worker(q, server_instance):
            last_gc_collect = current_time
            need_gc = False
        hook_breaker_ac10a0.restore_functions()

        if not asset_seeder.is_disabled():
            asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
            asset_seeder.resume()

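Worked examples of the adaptive cache_ram default above, assuming comfy.model_management.total_ram is reported in MiB (an assumption; only the formula comes from the diff):

# cache_ram = min(32.0, max(4.0, total_ram * 0.25 / 1024.0))   # result in GiB
#   8 GiB machine:   8192 * 0.25 / 1024 =  2.0  -> clamped up to 4.0
#  64 GiB machine:  65536 * 0.25 / 1024 = 16.0  -> 16.0
# 256 GiB machine: 262144 * 0.25 / 1024 = 64.0  -> clamped down to 32.0
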
@@ -1 +1 @@
-comfyui_manager==4.1b8
+comfyui_manager==4.1

nodes.py (3 changes)
@@ -2454,7 +2454,10 @@ async def init_builtin_extra_nodes():
        "nodes_nag.py",
        "nodes_sdpose.py",
        "nodes_math.py",
        "nodes_number_convert.py",
        "nodes_painter.py",
        "nodes_curve.py",
        "nodes_rtdetr.py"
    ]

    import_failed = []

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.8
-comfyui-workflow-templates==0.9.26
+comfyui-workflow-templates==0.9.39
comfyui-embedded-docs==0.4.3
torch
torchsde

@@ -709,6 +709,11 @@ class PromptServer():
                else:
                    info['output_node'] = False

                if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT == True:
                    info['has_intermediate_output'] = True
                else:
                    info['has_intermediate_output'] = False

                if hasattr(obj_class, 'CATEGORY'):
                    info['category'] = obj_class.CATEGORY

@@ -3,7 +3,7 @@ from pathlib import Path
from unittest.mock import patch

import pytest
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session

from app.assets.database.models import Base
@@ -23,6 +23,21 @@ def db_engine():
    return engine


@pytest.fixture
def db_engine_fk():
    """In-memory SQLite engine with foreign key enforcement enabled."""
    engine = create_engine("sqlite:///:memory:")

    @event.listens_for(engine, "connect")
    def _set_pragma(dbapi_connection, connection_record):
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    Base.metadata.create_all(engine)
    return engine


@pytest.fixture
def session(db_engine):
    """Session fixture for tests that need direct DB access."""

@@ -1,9 +1,11 @@
"""Tests for asset enrichment (mime_type and hash population)."""
import os
from pathlib import Path

from sqlalchemy.orm import Session

from app.assets.database.models import Asset, AssetReference
from app.assets.services.file_utils import get_mtime_ns
from app.assets.scanner import (
    ENRICHMENT_HASHED,
    ENRICHMENT_METADATA,
@@ -20,6 +22,13 @@ def _create_stub_asset(
    name: str | None = None,
) -> tuple[Asset, AssetReference]:
    """Create a stub asset with reference for testing enrichment."""
    # Use the real file's mtime so the optimistic guard in enrich_asset passes
    try:
        stat_result = os.stat(file_path, follow_symlinks=True)
        mtime_ns = get_mtime_ns(stat_result)
    except OSError:
        mtime_ns = 1234567890000000000

    asset = Asset(
        id=asset_id,
        hash=None,
@@ -35,7 +44,7 @@ def _create_stub_asset(
        name=name or f"test-asset-{asset_id}",
        owner_id="system",
        file_path=file_path,
-        mtime_ns=1234567890000000000,
+        mtime_ns=mtime_ns,
        enrichment_level=ENRICHMENT_STUB,
    )
    session.add(ref)

@@ -1,12 +1,18 @@
"""Tests for ingest services."""
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

import pytest
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session as SASession, Session

-from app.assets.database.models import Asset, AssetReference, Tag
+from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
from app.assets.database.queries import get_reference_tags
-from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
from app.assets.services.ingest import (
    _ingest_file_from_path,
    _register_existing_asset,
    ingest_existing_file,
)


class TestIngestFileFromPath:
@@ -235,3 +241,42 @@ class TestRegisterExistingAsset:

        assert result.created is True
        assert set(result.tags) == {"alpha", "beta"}


class TestIngestExistingFileTagFK:
    """Regression: ingest_existing_file must seed Tag rows before inserting
    AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""

    def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
        """With PRAGMA foreign_keys=ON, tags must exist in the tags table
        before they can be referenced in asset_reference_tags."""

        @contextmanager
        def _create_session():
            with SASession(db_engine_fk) as sess:
                yield sess

        file_path = temp_dir / "output.png"
        file_path.write_bytes(b"image data")

        with patch("app.assets.services.ingest.create_session", _create_session), \
             patch(
                 "app.assets.services.ingest.get_name_and_tags_from_asset_path",
                 return_value=("output.png", ["output"]),
             ):
            result = ingest_existing_file(
                abs_path=str(file_path),
                extra_tags=["my-job"],
            )

        assert result is True

        with SASession(db_engine_fk) as sess:
            tag_names = {t.name for t in sess.query(Tag).all()}
            assert "output" in tag_names
            assert "my-job" in tag_names

            ref_tags = sess.query(AssetReferenceTag).all()
            ref_tag_names = {rt.tag_name for rt in ref_tags}
            assert "output" in ref_tag_names
            assert "my-job" in ref_tag_names

tests-unit/assets_test/services/test_path_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
"""Tests for path_utils – asset category resolution."""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest

from app.assets.services.path_utils import get_asset_category_and_relative_path


@pytest.fixture
def fake_dirs():
    """Create temporary input, output, and temp directories."""
    with tempfile.TemporaryDirectory() as root:
        root_path = Path(root)
        input_dir = root_path / "input"
        output_dir = root_path / "output"
        temp_dir = root_path / "temp"
        models_dir = root_path / "models" / "checkpoints"
        for d in (input_dir, output_dir, temp_dir, models_dir):
            d.mkdir(parents=True)

        with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
            mock_fp.get_input_directory.return_value = str(input_dir)
            mock_fp.get_output_directory.return_value = str(output_dir)
            mock_fp.get_temp_directory.return_value = str(temp_dir)

            with patch(
                "app.assets.services.path_utils.get_comfy_models_folders",
                return_value=[("checkpoints", [str(models_dir)])],
            ):
                yield {
                    "input": input_dir,
                    "output": output_dir,
                    "temp": temp_dir,
                    "models": models_dir,
                }


class TestGetAssetCategoryAndRelativePath:
    def test_input_file(self, fake_dirs):
        f = fake_dirs["input"] / "photo.png"
        f.touch()
        cat, rel = get_asset_category_and_relative_path(str(f))
        assert cat == "input"
        assert rel == "photo.png"

    def test_output_file(self, fake_dirs):
        f = fake_dirs["output"] / "result.png"
        f.touch()
        cat, rel = get_asset_category_and_relative_path(str(f))
        assert cat == "output"
        assert rel == "result.png"

    def test_temp_file(self, fake_dirs):
        """Regression: temp files must be categorised, not raise ValueError."""
        f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
        f.touch()
        cat, rel = get_asset_category_and_relative_path(str(f))
        assert cat == "temp"
        assert rel == "GLSLShader_output_00004_.png"

    def test_temp_file_in_subfolder(self, fake_dirs):
        sub = fake_dirs["temp"] / "sub"
        sub.mkdir()
        f = sub / "ComfyUI_temp_tczip_00004_.png"
        f.touch()
        cat, rel = get_asset_category_and_relative_path(str(f))
        assert cat == "temp"
        assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")

    def test_model_file(self, fake_dirs):
        f = fake_dirs["models"] / "model.safetensors"
        f.touch()
        cat, rel = get_asset_category_and_relative_path(str(f))
        assert cat == "models"

    def test_unknown_path_raises(self, fake_dirs):
        with pytest.raises(ValueError, match="not within"):
            get_asset_category_and_relative_path("/some/random/path.png")

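These tests imply a resolution order of input/output/temp first, then the configured model folders. A rough reconstruction of that behaviour, not the shipped function in app.assets.services.path_utils:

import os

def categorize(abs_path, input_dir, output_dir, temp_dir, model_dirs):
    # Reconstructed from the assertions above; error message matches "not within".
    abs_path = os.path.abspath(abs_path)
    for cat, base in (("input", input_dir), ("output", output_dir), ("temp", temp_dir)):
        base = os.path.abspath(base)
        if abs_path.startswith(base + os.sep):
            return cat, os.path.relpath(abs_path, base)
    for _name, bases in model_dirs:
        for base in bases:
            base = os.path.abspath(base)
            if abs_path.startswith(base + os.sep):
                return "models", os.path.relpath(abs_path, base)
    raise ValueError(f"Path {abs_path!r} is not within any known asset root")
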
tests-unit/comfy_extras_test/nodes_number_convert_test.py (new file, 180 lines)
@@ -0,0 +1,180 @@
import pytest
from unittest.mock import patch, MagicMock

mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
mock_server = MagicMock()

with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
    from comfy_extras.nodes_number_convert import NumberConvertNode


class TestNumberConvertExecute:
    @staticmethod
    def _exec(value) -> object:
        return NumberConvertNode.execute(value)

    # --- INT input ---

    def test_int_input(self):
        result = self._exec(42)
        assert result[0] == 42.0
        assert result[1] == 42

    def test_int_zero(self):
        result = self._exec(0)
        assert result[0] == 0.0
        assert result[1] == 0

    def test_int_negative(self):
        result = self._exec(-7)
        assert result[0] == -7.0
        assert result[1] == -7

    # --- FLOAT input ---

    def test_float_input(self):
        result = self._exec(3.14)
        assert result[0] == 3.14
        assert result[1] == 3

    def test_float_truncation_toward_zero(self):
        result = self._exec(-2.9)
        assert result[0] == -2.9
        assert result[1] == -2  # int() truncates toward zero, not floor

    def test_float_output_type(self):
        result = self._exec(5)
        assert isinstance(result[0], float)

    def test_int_output_type(self):
        result = self._exec(5.7)
        assert isinstance(result[1], int)

    # --- BOOL input ---

    def test_bool_true(self):
        result = self._exec(True)
        assert result[0] == 1.0
        assert result[1] == 1

    def test_bool_false(self):
        result = self._exec(False)
        assert result[0] == 0.0
        assert result[1] == 0

    # --- STRING input ---

    def test_string_integer(self):
        result = self._exec("42")
        assert result[0] == 42.0
        assert result[1] == 42

    def test_string_float(self):
        result = self._exec("3.14")
        assert result[0] == 3.14
        assert result[1] == 3

    def test_string_negative(self):
        result = self._exec("-5.5")
        assert result[0] == -5.5
        assert result[1] == -5

    def test_string_with_whitespace(self):
        result = self._exec(" 7.0 ")
        assert result[0] == 7.0
        assert result[1] == 7

    def test_string_scientific_notation(self):
        result = self._exec("1e3")
        assert result[0] == 1000.0
        assert result[1] == 1000

    # --- Large number precision (string input) ---

    def test_string_large_int_above_2_53(self):
        """Text-to-int must not lose precision for integers beyond 2^53."""
        big = 2**53 + 1  # 9007199254740993
        result = self._exec(str(big))
        assert result[1] == big

    def test_string_large_negative_int_above_2_53(self):
        big = -(2**53 + 1)
        result = self._exec(str(big))
        assert result[1] == big

    def test_string_very_large_int(self):
        big = 2**63 + 42
        result = self._exec(str(big))
        assert result[1] == big

    def test_string_large_int_float_output_is_float(self):
        """FLOAT output is still a float (may lose precision, but must be float type)."""
        result = self._exec(str(2**53 + 1))
        assert isinstance(result[0], float)

    # --- Large number precision (int input) ---

    def test_int_large_above_2_53(self):
        """Native int input must preserve its value in the INT output."""
        big = 2**53 + 1
        result = self._exec(big)
        assert result[1] == big

    def test_int_large_negative_above_2_53(self):
        big = -(2**53 + 1)
        result = self._exec(big)
        assert result[1] == big

    def test_int_very_large(self):
        big = 2**100
        result = self._exec(big)
        assert result[1] == big

    # --- String decimal / scientific notation fallback ---

    def test_string_decimal_still_truncates(self):
        """Strings with decimal points fall back to int(float(...)) truncation."""
        result = self._exec("3.7")
        assert result[1] == 3

    def test_string_negative_decimal_truncates(self):
        result = self._exec("-2.9")
        assert result[1] == -2

    def test_string_scientific_large(self):
        result = self._exec("1e18")
        assert result[0] == 1e18
        assert result[1] == 10**18

    # --- STRING error paths ---

    def test_empty_string_raises(self):
        with pytest.raises(ValueError, match="Cannot convert empty string"):
            self._exec("")

    def test_whitespace_only_string_raises(self):
        with pytest.raises(ValueError, match="Cannot convert empty string"):
            self._exec(" ")

    def test_non_numeric_string_raises(self):
        with pytest.raises(ValueError, match="Cannot convert string to number"):
            self._exec("abc")

    def test_string_inf_raises(self):
        with pytest.raises(ValueError, match="non-finite"):
            self._exec("inf")

    def test_string_nan_raises(self):
        with pytest.raises(ValueError, match="non-finite"):
            self._exec("nan")

    def test_string_negative_inf_raises(self):
        with pytest.raises(ValueError, match="non-finite"):
            self._exec("-inf")

    # --- Unsupported type ---

    def test_unsupported_type_raises(self):
        with pytest.raises(TypeError, match="Unsupported input type"):
            self._exec([1, 2, 3])

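The behaviour these tests pin down suggests a conversion roughly like the following; this is a reconstruction from the assertions, not the shipped node body:

import math

def _to_float_int(value):
    # Reconstructed from the tests above -- not the actual implementation.
    if isinstance(value, bool):          # check before int: bool subclasses int
        return float(value), int(value)
    if isinstance(value, (int, float)):
        return float(value), int(value)  # int() truncates toward zero
    if isinstance(value, str):
        s = value.strip()
        if not s:
            raise ValueError("Cannot convert empty string")
        try:
            as_int = int(s)              # exact even beyond 2**53
            return float(as_int), as_int
        except ValueError:
            pass
        try:
            as_float = float(s)
        except ValueError:
            raise ValueError(f"Cannot convert string to number: {s!r}")
        if not math.isfinite(as_float):
            raise ValueError(f"Cannot convert non-finite value: {s!r}")
        return as_float, int(as_float)
    raise TypeError(f"Unsupported input type: {type(value).__name__}")
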
@@ -1,6 +1,7 @@
"""Unit tests for the _AssetSeeder background scanning class."""

import threading
import time
from unittest.mock import patch

import pytest
@@ -771,6 +772,188 @@ class TestSeederStopRestart:
        assert collected_roots[1] == ("input",)


class TestEnqueueEnrichHandoff:
    """Test that the drain of _pending_enrich is atomic with start_enrich."""

    def test_pending_enrich_runs_after_scan_completes(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """A queued enrich request runs automatically when a scan finishes."""
        enrich_roots_seen: list[tuple] = []
        original_start = fresh_seeder.start

        def tracking_start(*args, **kwargs):
            phase = kwargs.get("phase")
            roots = kwargs.get("roots", args[0] if args else None)
            result = original_start(*args, **kwargs)
            if phase == ScanPhase.ENRICH and result:
                enrich_roots_seen.append(roots)
            return result

        fresh_seeder.start = tracking_start

        # Start a fast scan, then enqueue an enrich while it's running
        barrier = threading.Event()
        reached = threading.Event()

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
            assert reached.wait(timeout=2.0)

            queued = fresh_seeder.enqueue_enrich(
                roots=("input",), compute_hashes=True
            )
            assert queued is False  # queued, not started immediately

            barrier.set()

            # Wait for the original scan + the auto-started enrich scan
            deadline = time.monotonic() + 5.0
            while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
                time.sleep(0.05)

        assert enrich_roots_seen == [("input",)]

    def test_enqueue_enrich_during_drain_does_not_lose_work(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """enqueue_enrich called concurrently with drain cannot drop work.

        Simulates the race: another thread calls enqueue_enrich right as the
        scan thread is draining _pending_enrich. The enqueue must either be
        picked up by the draining scan or successfully start its own scan.
        """
        barrier = threading.Event()
        reached = threading.Event()
        enrich_started = threading.Event()

        enrich_call_count = 0

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        # Track how many times start_enrich actually fires
        real_start_enrich = fresh_seeder.start_enrich
        enrich_roots_seen: list[tuple] = []

        def tracking_start_enrich(**kwargs):
            nonlocal enrich_call_count
            enrich_call_count += 1
            enrich_roots_seen.append(kwargs.get("roots"))
            result = real_start_enrich(**kwargs)
            if result:
                enrich_started.set()
            return result

        fresh_seeder.start_enrich = tracking_start_enrich

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            # Start a scan
            fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
            assert reached.wait(timeout=2.0)

            # Queue an enrich while scan is running
            fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)

            # Let scan finish — drain will fire start_enrich atomically
            barrier.set()

            # Wait for drain to complete and the enrich scan to start
            assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
            assert ("output",) in enrich_roots_seen

    def test_concurrent_enqueue_during_drain_not_lost(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A second enqueue_enrich arriving while drain is in progress is not lost.

        Because the drain now holds _lock through the start_enrich call,
        a concurrent enqueue_enrich will block until start_enrich has
        transitioned state to RUNNING, then the enqueue will queue its
        payload as _pending_enrich for the *next* drain.
        """
        scan_barrier = threading.Event()
        scan_reached = threading.Event()
        enrich_barrier = threading.Event()
        enrich_reached = threading.Event()

        collect_call = 0

        def gated_collect(*args):
            nonlocal collect_call
            collect_call += 1
            if collect_call == 1:
                # First call: the initial fast scan
                scan_reached.set()
                scan_barrier.wait(timeout=5.0)
            return []

        enrich_call = 0

        def gated_get_unenriched(*args, **kwargs):
            nonlocal enrich_call
            enrich_call += 1
            if enrich_call == 1:
                # First enrich batch: signal and block
                enrich_reached.set()
                enrich_barrier.wait(timeout=5.0)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            # 1. Start fast scan
            fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
            assert scan_reached.wait(timeout=2.0)

            # 2. Queue enrich while fast scan is running
            queued = fresh_seeder.enqueue_enrich(
                roots=("input",), compute_hashes=False
            )
            assert queued is False

            # 3. Let the fast scan finish — drain will start the enrich scan
            scan_barrier.set()

            # 4. Wait until the drained enrich scan is running
            assert enrich_reached.wait(timeout=5.0)

            # 5. Now enqueue another enrich while the drained scan is running
            queued2 = fresh_seeder.enqueue_enrich(
                roots=("output",), compute_hashes=True
            )
            assert queued2 is False  # should be queued, not started

            # Verify _pending_enrich was set (the second enqueue was captured)
            with fresh_seeder._lock:
                assert fresh_seeder._pending_enrich is not None
                assert "output" in fresh_seeder._pending_enrich["roots"]

            # Let the enrich scan finish
            enrich_barrier.set()

            deadline = time.monotonic() + 5.0
            while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
                time.sleep(0.05)


def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
    return UnenrichedReferenceRow(
        reference_id=ref_id, asset_id=asset_id,

tests/test_asset_seeder.py (new file, 250 lines)
@@ -0,0 +1,250 @@
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""

import threading
from unittest.mock import patch

import pytest

from app.assets.seeder import Progress, _AssetSeeder, State


@pytest.fixture()
def seeder():
    """Fresh seeder instance for each test."""
    return _AssetSeeder()


# ---------------------------------------------------------------------------
# _reset_to_idle
# ---------------------------------------------------------------------------


class TestResetToIdle:
    def test_sets_idle_and_clears_progress(self, seeder):
        """_reset_to_idle should move state to IDLE and snapshot progress."""
        progress = Progress(scanned=10, total=20, created=5, skipped=3)
        seeder._state = State.RUNNING
        seeder._progress = progress

        with seeder._lock:
            seeder._reset_to_idle()

        assert seeder._state is State.IDLE
        assert seeder._progress is None
        assert seeder._last_progress is progress

    def test_noop_when_progress_already_none(self, seeder):
        """_reset_to_idle should handle None progress gracefully."""
        seeder._state = State.CANCELLING
        seeder._progress = None

        with seeder._lock:
            seeder._reset_to_idle()

        assert seeder._state is State.IDLE
        assert seeder._progress is None
        assert seeder._last_progress is None


# ---------------------------------------------------------------------------
# enqueue_enrich – immediate start when idle
# ---------------------------------------------------------------------------


class TestEnqueueEnrichStartsImmediately:
    def test_starts_when_idle(self, seeder):
        """enqueue_enrich should delegate to start_enrich and return True when idle."""
        with patch.object(seeder, "start_enrich", return_value=True) as mock:
            assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
            mock.assert_called_once_with(roots=("output",), compute_hashes=True)

    def test_no_pending_when_started_immediately(self, seeder):
        """No pending request should be stored when start_enrich succeeds."""
        with patch.object(seeder, "start_enrich", return_value=True):
            seeder.enqueue_enrich(roots=("output",))
        assert seeder._pending_enrich is None


# ---------------------------------------------------------------------------
# enqueue_enrich – queuing when busy
# ---------------------------------------------------------------------------


class TestEnqueueEnrichQueuesWhenBusy:
    def test_queues_when_busy(self, seeder):
        """enqueue_enrich should store a pending request when seeder is busy."""
        with patch.object(seeder, "start_enrich", return_value=False):
            result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)

        assert result is False
        assert seeder._pending_enrich == {
            "roots": ("models",),
            "compute_hashes": False,
        }

    def test_queues_preserves_compute_hashes_true(self, seeder):
        with patch.object(seeder, "start_enrich", return_value=False):
            seeder.enqueue_enrich(roots=("input",), compute_hashes=True)

        assert seeder._pending_enrich["compute_hashes"] is True


# ---------------------------------------------------------------------------
# enqueue_enrich – merging when a pending request already exists
# ---------------------------------------------------------------------------


class TestEnqueueEnrichMergesPending:
    def _make_busy(self, seeder):
        """Patch start_enrich to always return False (seeder busy)."""
        return patch.object(seeder, "start_enrich", return_value=False)

    def test_merges_roots(self, seeder):
        """A second enqueue should merge roots with the existing pending request."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",))
            seeder.enqueue_enrich(roots=("output",))

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "output"}

    def test_merges_overlapping_roots(self, seeder):
        """Duplicate roots should be deduplicated."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models", "input"))
            seeder.enqueue_enrich(roots=("input", "output"))

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "input", "output"}

    def test_compute_hashes_sticky_true(self, seeder):
        """Once compute_hashes is True it should stay True after merging."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=False)

        assert seeder._pending_enrich["compute_hashes"] is True

    def test_compute_hashes_upgrades_to_true(self, seeder):
        """A later enqueue with compute_hashes=True should upgrade the pending request."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=True)

        assert seeder._pending_enrich["compute_hashes"] is True

    def test_compute_hashes_stays_false(self, seeder):
        """If both enqueues have compute_hashes=False it stays False."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=False)

        assert seeder._pending_enrich["compute_hashes"] is False

    def test_triple_merge(self, seeder):
        """Three successive enqueues should all merge correctly."""
        with self._make_busy(seeder):
            seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
            seeder.enqueue_enrich(roots=("output",), compute_hashes=True)

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "input", "output"}
        assert seeder._pending_enrich["compute_hashes"] is True


# ---------------------------------------------------------------------------
# Pending enrich drains after scan completes
# ---------------------------------------------------------------------------


class TestPendingEnrichDrain:
    """Verify that _run_scan drains _pending_enrich via start_enrich."""

    @patch("app.assets.seeder.dependencies_available", return_value=True)
    @patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
    @patch("app.assets.seeder.sync_root_safely", return_value=set())
    @patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
    @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
    def test_pending_enrich_starts_after_scan(self, *_mocks):
        """After a fast scan finishes, the pending enrich should be started."""
        seeder = _AssetSeeder()

        seeder._pending_enrich = {
            "roots": ("output",),
            "compute_hashes": True,
        }

        with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
            seeder.start_fast(roots=("models",))
            seeder.wait(timeout=5)

            mock_start.assert_called_once_with(
                roots=("output",),
                compute_hashes=True,
            )

        assert seeder._pending_enrich is None

    @patch("app.assets.seeder.dependencies_available", return_value=True)
    @patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
    @patch("app.assets.seeder.sync_root_safely", return_value=set())
    @patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
    @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
    def test_pending_cleared_even_when_start_fails(self, *_mocks):
        """_pending_enrich should be cleared even if start_enrich returns False."""
        seeder = _AssetSeeder()
        seeder._pending_enrich = {
            "roots": ("output",),
            "compute_hashes": False,
        }

        with patch.object(seeder, "start_enrich", return_value=False):
            seeder.start_fast(roots=("models",))
            seeder.wait(timeout=5)

        assert seeder._pending_enrich is None

    @patch("app.assets.seeder.dependencies_available", return_value=True)
    @patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
    @patch("app.assets.seeder.sync_root_safely", return_value=set())
    @patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
    @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
    def test_no_drain_when_no_pending(self, *_mocks):
        """start_enrich should not be called when there is no pending request."""
        seeder = _AssetSeeder()
        assert seeder._pending_enrich is None

        with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
            seeder.start_fast(roots=("models",))
            seeder.wait(timeout=5)

        mock_start.assert_not_called()


# ---------------------------------------------------------------------------
# Thread-safety of enqueue_enrich
# ---------------------------------------------------------------------------


class TestEnqueueEnrichThreadSafety:
    def test_concurrent_enqueues(self, seeder):
        """Multiple threads enqueuing should not lose roots."""
        with patch.object(seeder, "start_enrich", return_value=False):
            barrier = threading.Barrier(3)

            def enqueue(root):
                barrier.wait()
                seeder.enqueue_enrich(roots=(root,), compute_hashes=False)

            threads = [
                threading.Thread(target=enqueue, args=(r,))
                for r in ("models", "input", "output")
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join(timeout=5)

        merged = set(seeder._pending_enrich["roots"])
        assert merged == {"models", "input", "output"}

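The merge semantics pinned down by these tests can be summarised in a few lines; a reconstruction from the assertions, not the shipped seeder code:

def merge_pending(pending, roots, compute_hashes):
    # Reconstructed from the tests above -- not the actual implementation.
    if pending is None:
        return {"roots": tuple(roots), "compute_hashes": compute_hashes}
    merged_roots = tuple(dict.fromkeys(pending["roots"] + tuple(roots)))  # dedupe, keep order
    return {
        "roots": merged_roots,
        "compute_hashes": pending["compute_hashes"] or compute_hashes,  # sticky True
    }
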
@@ -24,6 +24,7 @@ def init_mime_types():
    # Web types (used by server.py for static file serving)
    mimetypes.add_type('application/javascript; charset=utf-8', '.js')
    mimetypes.add_type('image/webp', '.webp')
    mimetypes.add_type('image/svg+xml', '.svg')

    # Model and data file types (used by asset scanning / metadata extraction)
    mimetypes.add_type("application/safetensors", ".safetensors")