mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-27 07:57:31 +00:00
Compare commits
25 Commits
release/v0
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
225c52f6a4 | ||
|
|
b1fdbeb9a7 | ||
|
|
1dc64f3526 | ||
|
|
359559c913 | ||
|
|
8165485a17 | ||
|
|
b0fd65e884 | ||
|
|
2a1f402601 | ||
|
|
3eba2dcf2d | ||
|
|
404d7b9978 | ||
|
|
6580a6bc01 | ||
|
|
3b15651bc6 | ||
|
|
a55835f10c | ||
|
|
b53b10ea61 | ||
|
|
7d5534d8e5 | ||
|
|
5ebb0c2e0b | ||
|
|
a0a64c679f | ||
|
|
8e73678dae | ||
|
|
c2862b24af | ||
|
|
f9ec85f739 | ||
|
|
2d5fd3f5dd | ||
|
|
2d4970ff67 | ||
|
|
e87858e974 | ||
|
|
da6edb5a4e | ||
|
|
6265a239f3 | ||
|
|
d49420b3c7 |
@@ -1,6 +1,7 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
create_stub_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
count_active_siblings,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
@@ -80,6 +82,8 @@ __all__ = [
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"count_active_siblings",
|
||||
"create_stub_asset",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
|
||||
@@ -78,6 +78,18 @@ def upsert_asset(
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
def create_stub_asset(
|
||||
session: Session,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> Asset:
|
||||
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
|
||||
@@ -114,6 +114,23 @@ def get_reference_by_file_path(
|
||||
)
|
||||
|
||||
|
||||
def count_active_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
exclude_reference_id: str,
|
||||
) -> int:
|
||||
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||
return (
|
||||
session.query(AssetReference)
|
||||
.filter(
|
||||
AssetReference.asset_id == asset_id,
|
||||
AssetReference.id != exclude_reference_id,
|
||||
AssetReference.deleted_at.is_(None),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
def reference_exists_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
@@ -338,6 +339,7 @@ def build_asset_specs(
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
"job_id": None,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
@@ -426,6 +428,7 @@ def enrich_asset(
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
@@ -489,6 +492,18 @@ def enrich_asset(
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||
# started (e.g. ingest_existing_file updated it), our results are
|
||||
# stale — discard them to avoid overwriting fresh registration data.
|
||||
ref = get_reference_by_id(session, reference_id)
|
||||
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||
session.rollback()
|
||||
logging.info(
|
||||
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||
reference_id,
|
||||
)
|
||||
return ENRICHMENT_STUB
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
@@ -77,7 +77,9 @@ class _AssetSeeder:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
# RLock is required because _run_scan() drains pending work while
|
||||
# holding _lock and re-enters start() which also acquires _lock.
|
||||
self._lock = threading.RLock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
@@ -92,6 +94,7 @@ class _AssetSeeder:
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
self._pending_enrich: dict | None = None
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
@@ -196,6 +199,42 @@ class _AssetSeeder:
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def enqueue_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||
|
||||
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||
request is stored and will run automatically when the current scan
|
||||
finishes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if started immediately, False if queued for later
|
||||
"""
|
||||
with self._lock:
|
||||
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||
return True
|
||||
if self._pending_enrich is not None:
|
||||
existing_roots = set(self._pending_enrich["roots"])
|
||||
existing_roots.update(roots)
|
||||
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||
self._pending_enrich["compute_hashes"] = (
|
||||
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||
)
|
||||
else:
|
||||
self._pending_enrich = {
|
||||
"roots": roots,
|
||||
"compute_hashes": compute_hashes,
|
||||
}
|
||||
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||
return False
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
@@ -381,9 +420,13 @@ class _AssetSeeder:
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
|
||||
def _reset_to_idle(self) -> None:
|
||||
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
@@ -594,9 +637,18 @@ class _AssetSeeder:
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
pending = self._pending_enrich
|
||||
if pending is not None:
|
||||
self._pending_enrich = None
|
||||
if not self.start_enrich(
|
||||
roots=pending["roots"],
|
||||
compute_hashes=pending["compute_hashes"],
|
||||
):
|
||||
logging.warning(
|
||||
"Pending enrich scan could not start (roots=%s)",
|
||||
pending["roots"],
|
||||
)
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
ingest_existing_file,
|
||||
register_output_files,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
@@ -72,6 +74,8 @@ __all__ = [
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"ingest_existing_file",
|
||||
"register_output_files",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
|
||||
@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
job_id: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
job_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"job_id": spec.get("job_id"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
|
||||
@@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
count_active_siblings,
|
||||
create_stub_asset,
|
||||
ensure_tags_exist,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_file_path,
|
||||
@@ -23,7 +26,8 @@ from app.assets.database.queries import (
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
@@ -130,6 +134,102 @@ def _ingest_file_from_path(
|
||||
)
|
||||
|
||||
|
||||
def register_output_files(
|
||||
file_paths: Sequence[str],
|
||||
user_metadata: UserMetadata = None,
|
||||
job_id: str | None = None,
|
||||
) -> int:
|
||||
"""Register a batch of output file paths as assets.
|
||||
|
||||
Returns the number of files successfully registered.
|
||||
"""
|
||||
registered = 0
|
||||
for abs_path in file_paths:
|
||||
if not os.path.isfile(abs_path):
|
||||
continue
|
||||
try:
|
||||
if ingest_existing_file(
|
||||
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||
):
|
||||
registered += 1
|
||||
except Exception:
|
||||
logging.exception("Failed to register output: %s", abs_path)
|
||||
return registered
|
||||
|
||||
|
||||
def ingest_existing_file(
|
||||
abs_path: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
extra_tags: Sequence[str] = (),
|
||||
owner_id: str = "",
|
||||
job_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Register an existing on-disk file as an asset stub.
|
||||
|
||||
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||
|
||||
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||
UX visibility.
|
||||
|
||||
Returns True if a row was inserted or updated, False otherwise.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||
|
||||
with create_session() as session:
|
||||
existing_ref = get_reference_by_file_path(session, locator)
|
||||
if existing_ref is not None:
|
||||
now = get_utc_now()
|
||||
existing_ref.mtime_ns = mtime_ns
|
||||
existing_ref.job_id = job_id
|
||||
existing_ref.is_missing = False
|
||||
existing_ref.deleted_at = None
|
||||
existing_ref.updated_at = now
|
||||
existing_ref.enrichment_level = 0
|
||||
|
||||
asset = existing_ref.asset
|
||||
if asset:
|
||||
# If other refs share this asset, detach to a new stub
|
||||
# instead of mutating the shared row.
|
||||
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||
if siblings > 0:
|
||||
new_asset = create_stub_asset(
|
||||
session,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type or asset.mime_type,
|
||||
)
|
||||
existing_ref.asset_id = new_asset.id
|
||||
else:
|
||||
asset.hash = None
|
||||
asset.size_bytes = size_bytes
|
||||
if mime_type:
|
||||
asset.mime_type = mime_type
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
spec = {
|
||||
"abs_path": abs_path,
|
||||
"size_bytes": size_bytes,
|
||||
"mtime_ns": mtime_ns,
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": os.path.basename(abs_path),
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": mime_type,
|
||||
"job_id": job_id,
|
||||
}
|
||||
if tags:
|
||||
ensure_tags_exist(session, tags)
|
||||
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||
session.commit()
|
||||
return result.won_paths > 0
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
|
||||
@@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "models"], str]:
|
||||
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
- 'input': under folder_paths.get_input_directory()
|
||||
- 'output': under folder_paths.get_output_directory()
|
||||
- 'temp': under folder_paths.get_temp_directory()
|
||||
- 'models': under any base path from get_comfy_models_folders()
|
||||
|
||||
Returns:
|
||||
@@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
|
||||
if _check_is_within(fp_abs, output_base):
|
||||
return "output", _compute_relative(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
# 3) temp
|
||||
temp_base = os.path.abspath(folder_paths.get_temp_directory())
|
||||
if _check_is_within(fp_abs, temp_base):
|
||||
return "temp", _compute_relative(fp_abs, temp_base)
|
||||
|
||||
# 4) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
@@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
||||
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
@@ -0,0 +1,90 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform float u_float0;
|
||||
uniform float u_float1;
|
||||
uniform float u_float2;
|
||||
uniform float u_float3;
|
||||
uniform float u_float4;
|
||||
uniform float u_float5;
|
||||
uniform float u_float6;
|
||||
uniform float u_float7;
|
||||
uniform float u_float8;
|
||||
uniform bool u_bool0;
|
||||
|
||||
in vec2 v_texCoord;
|
||||
out vec4 fragColor;
|
||||
|
||||
vec3 rgb2hsl(vec3 c) {
|
||||
float maxC = max(c.r, max(c.g, c.b));
|
||||
float minC = min(c.r, min(c.g, c.b));
|
||||
float l = (maxC + minC) * 0.5;
|
||||
if (maxC == minC) return vec3(0.0, 0.0, l);
|
||||
float d = maxC - minC;
|
||||
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
|
||||
float h;
|
||||
if (maxC == c.r) {
|
||||
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
|
||||
} else if (maxC == c.g) {
|
||||
h = (c.b - c.r) / d + 2.0;
|
||||
} else {
|
||||
h = (c.r - c.g) / d + 4.0;
|
||||
}
|
||||
h /= 6.0;
|
||||
return vec3(h, s, l);
|
||||
}
|
||||
|
||||
float hue2rgb(float p, float q, float t) {
|
||||
if (t < 0.0) t += 1.0;
|
||||
if (t > 1.0) t -= 1.0;
|
||||
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
|
||||
if (t < 1.0 / 2.0) return q;
|
||||
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
|
||||
return p;
|
||||
}
|
||||
|
||||
vec3 hsl2rgb(vec3 hsl) {
|
||||
float h = hsl.x, s = hsl.y, l = hsl.z;
|
||||
if (s == 0.0) return vec3(l);
|
||||
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
|
||||
float p = 2.0 * l - q;
|
||||
return vec3(
|
||||
hue2rgb(p, q, h + 1.0 / 3.0),
|
||||
hue2rgb(p, q, h),
|
||||
hue2rgb(p, q, h - 1.0 / 3.0)
|
||||
);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 tex = texture(u_image0, v_texCoord);
|
||||
vec3 color = tex.rgb;
|
||||
|
||||
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
|
||||
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
|
||||
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
|
||||
|
||||
float maxC = max(color.r, max(color.g, color.b));
|
||||
float minC = min(color.r, min(color.g, color.b));
|
||||
float lightness = (maxC + minC) * 0.5;
|
||||
|
||||
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
|
||||
const float a = 0.25;
|
||||
const float b = 0.333;
|
||||
const float scale = 0.7;
|
||||
|
||||
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
|
||||
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
|
||||
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
|
||||
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
|
||||
|
||||
color += sw * shadows + mw * midtones + hw * highlights;
|
||||
|
||||
if (u_bool0) {
|
||||
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
|
||||
hsl.z = lightness;
|
||||
color = hsl2rgb(hsl);
|
||||
}
|
||||
|
||||
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||
}
|
||||
46
blueprints/.glsl/Color_Curves_8.frag
Normal file
46
blueprints/.glsl/Color_Curves_8.frag
Normal file
@@ -0,0 +1,46 @@
|
||||
#version 300 es
|
||||
precision highp float;
|
||||
|
||||
uniform sampler2D u_image0;
|
||||
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
|
||||
uniform sampler2D u_curve1; // Red channel curve
|
||||
uniform sampler2D u_curve2; // Green channel curve
|
||||
uniform sampler2D u_curve3; // Blue channel curve
|
||||
|
||||
in vec2 v_texCoord;
|
||||
layout(location = 0) out vec4 fragColor0;
|
||||
|
||||
// GIMP-compatible curve lookup with manual linear interpolation.
|
||||
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
|
||||
// index = value * (n_samples - 1)
|
||||
// f = fract(index)
|
||||
// result = (1-f) * samples[floor] + f * samples[ceil]
|
||||
//
|
||||
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
|
||||
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
|
||||
float applyCurve(sampler2D curve, float value) {
|
||||
value = clamp(value, 0.0, 1.0);
|
||||
|
||||
float pos = value * 255.0;
|
||||
int lo = int(floor(pos));
|
||||
int hi = min(lo + 1, 255);
|
||||
float f = pos - float(lo);
|
||||
|
||||
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
|
||||
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
|
||||
|
||||
return a + f * (b - a);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 color = texture(u_image0, v_texCoord);
|
||||
|
||||
// GIMP order: per-channel curves first, then RGB master curve.
|
||||
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||
// dest = colors_curve( channel_curve( src ) )
|
||||
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
|
||||
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
|
||||
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
|
||||
|
||||
fragColor0 = vec4(color.rgb, color.a);
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
1
blueprints/Color Balance.json
Normal file
1
blueprints/Color Balance.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Curves.json
Normal file
1
blueprints/Color Curves.json
Normal file
File diff suppressed because one or more lines are too long
@@ -386,7 +386,7 @@ class Flux(nn.Module):
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
|
||||
@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
|
||||
# Inject reference audio for ID-LoRA in-context conditioning
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
ref_audio_seq_len = 0
|
||||
if ref_audio is not None:
|
||||
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||
if ref_tokens.shape[0] < ax.shape[0]:
|
||||
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||
ref_audio_seq_len = ref_tokens.shape[1]
|
||||
B = ax.shape[0]
|
||||
|
||||
# Compute negative temporal positions matching ID-LoRA convention:
|
||||
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||
p = self.a_patchifier
|
||||
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||
time_offset = ref_end[-1].item() + tpl
|
||||
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||
|
||||
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||
target_len = kwargs.get("target_audio_seq_len")
|
||||
if a_timestep.dim() <= 1:
|
||||
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Trim reference audio tokens before unpatchification
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0:
|
||||
ax = ax[:, ref_audio_seq_len:]
|
||||
if a_embedded_timestep.shape[1] > 1:
|
||||
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
@@ -937,9 +937,10 @@ class LongCatImage(Flux):
|
||||
transformer_options = transformer_options.copy()
|
||||
rope_opts = transformer_options.get("rope_options", {})
|
||||
rope_opts = dict(rope_opts)
|
||||
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
|
||||
rope_opts.setdefault("shift_t", 1.0)
|
||||
rope_opts.setdefault("shift_y", 512.0)
|
||||
rope_opts.setdefault("shift_x", 512.0)
|
||||
rope_opts.setdefault("shift_y", pe_len)
|
||||
rope_opts.setdefault("shift_x", pe_len)
|
||||
transformer_options["rope_options"] = rope_opts
|
||||
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
@@ -1060,6 +1061,10 @@ class LTXAV(BaseModel):
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
if ref_audio is not None:
|
||||
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
|
||||
@@ -55,6 +55,7 @@ total_vram = 0
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
training_fp8_bwd = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
|
||||
69
comfy/ops.py
69
comfy/ops.py
@@ -777,8 +777,16 @@ from .quant_ops import (
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||
|
||||
When training_fp8_bwd is enabled:
|
||||
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||
- Cached input is FP8 (half the memory of bf16)
|
||||
|
||||
When training_fp8_bwd is disabled:
|
||||
- Forward: quantize input per layout, use quantized matmul
|
||||
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
# Quantize input for forward (same layout as weight)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
@@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
# Unflatten output to match original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
# Save for backward
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||
|
||||
if ctx.fp8_bwd:
|
||||
# Cache FP8 quantized input — half the memory of bf16
|
||||
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||
ctx.q_input = q_input # already FP8, reuse
|
||||
else:
|
||||
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||
ctx.save_for_backward(weight)
|
||||
else:
|
||||
ctx.q_input = None
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
# Value casting — only difference between fp8 and non-fp8 paths
|
||||
if ctx.fp8_bwd:
|
||||
weight, = ctx.saved_tensors
|
||||
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||
weight_mm = weight
|
||||
elif isinstance(weight, QuantizedTensor):
|
||||
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
else:
|
||||
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
input_mm = ctx.q_input
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
input_float, weight = ctx.saved_tensors
|
||||
# Standard tensors → torch.mm does regular matmul
|
||||
grad_mm = grad_2d
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_mm = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_mm = weight.to(compute_dtype)
|
||||
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
# Computation — same for both paths, dispatch handles the rest
|
||||
grad_input = torch.mm(grad_mm, weight_mm)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
@@ -895,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
self.weight = None
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
@@ -1001,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
|
||||
33
comfy/sd.py
33
comfy/sd.py
@@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@@ -425,13 +426,13 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@@ -1228,6 +1229,11 @@ class TEModel(Enum):
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
GEMMA_3_4B_VISION = 22
|
||||
QWEN35_08B = 23
|
||||
QWEN35_2B = 24
|
||||
QWEN35_4B = 25
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@@ -1267,6 +1273,17 @@ def detect_te_model(sd):
|
||||
return TEModel.QWEN25_3B
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
|
||||
weight = sd['model.language_model.layers.0.input_layernorm.weight']
|
||||
if weight.shape[0] == 1024:
|
||||
return TEModel.QWEN35_08B
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN35_4B
|
||||
if weight.shape[0] == 4096:
|
||||
return TEModel.QWEN35_9B
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@@ -1299,11 +1316,12 @@ def t5xxl_detect(clip_data):
|
||||
return {}
|
||||
|
||||
def llama_detect(clip_data):
|
||||
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
||||
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
|
||||
|
||||
for sd in clip_data:
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||
for weight_name in weight_names:
|
||||
if weight_name in sd:
|
||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||
|
||||
return {}
|
||||
|
||||
@@ -1431,6 +1449,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
|
||||
@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||
if isinstance(tokens, dict):
|
||||
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||
else:
|
||||
tokens_only = tokens
|
||||
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
lm_head: bool = True
|
||||
stop_tokens = [151643, 151645]
|
||||
|
||||
@dataclass
|
||||
@@ -655,6 +655,17 @@ class Llama2_(nn.Module):
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
return past_key_values[0][2]
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
return precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
@@ -667,17 +678,12 @@ class Llama2_(nn.Module):
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
past_len = past_key_values[0][2]
|
||||
past_len = self.get_past_len(past_key_values)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
self.config.rope_dims,
|
||||
device=x.device)
|
||||
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
@@ -812,9 +818,16 @@ class BaseGenerate:
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
|
||||
device = embeds.device
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
stop_tokens = self.model.config.stop_tokens
|
||||
@@ -829,11 +842,8 @@ class BaseGenerate:
|
||||
if embeds.ndim == 2:
|
||||
embeds = embeds.unsqueeze(0)
|
||||
|
||||
past_key_values = [] #kv_cache init
|
||||
max_cache_len = embeds.shape[1] + max_length
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||
|
||||
@@ -844,7 +854,7 @@ class BaseGenerate:
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
@@ -856,7 +866,7 @@ class BaseGenerate:
|
||||
|
||||
return generated_token_ids
|
||||
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
|
||||
|
||||
if not do_sample or temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
@@ -867,6 +877,11 @@ class BaseGenerate:
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if presence_penalty is not None and presence_penalty != 0.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] -= presence_penalty
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
@@ -897,6 +912,9 @@ class BaseGenerate:
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
if self.model.config.lm_head:
|
||||
return self.model.lm_head(input)
|
||||
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
@@ -1028,12 +1046,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
if attention_mask is not None:
|
||||
# Assign compact sequential positions to attended tokens only,
|
||||
# skipping over padding so post-padding tokens aren't inflated.
|
||||
after_mask = attention_mask[0, end:]
|
||||
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
|
||||
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
|
||||
else:
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
|
||||
@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
|
||||
return [output]
|
||||
|
||||
|
||||
IMAGE_PAD_TOKEN_ID = 151655
|
||||
|
||||
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
name="qwen25_7b",
|
||||
tokenizer=LongCatImageBaseTokenizer,
|
||||
)
|
||||
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith("<|im_start|>"):
|
||||
skip_template = True
|
||||
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
|
||||
)
|
||||
else:
|
||||
has_images = images is not None and len(images) > 0
|
||||
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
|
||||
|
||||
prefix_ids = base_tok.tokenizer(
|
||||
self.longcat_template_prefix, add_special_tokens=False
|
||||
template_prefix, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
suffix_ids = base_tok.tokenizer(
|
||||
self.longcat_template_suffix, add_special_tokens=False
|
||||
self.SUFFIX, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
|
||||
prompt_tokens = base_tok.tokenize_with_weights(
|
||||
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
suffix_pairs = [(t, 1.0) for t in suffix_ids]
|
||||
|
||||
combined = prefix_pairs + prompt_pairs + suffix_pairs
|
||||
|
||||
if has_images:
|
||||
embed_count = 0
|
||||
for i in range(len(combined)):
|
||||
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
|
||||
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
|
||||
embed_count += 1
|
||||
|
||||
tokens = {"qwen25_7b": [combined]}
|
||||
|
||||
return tokens
|
||||
|
||||
@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
self.dtypes.add(dtype)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
|
||||
833
comfy/text_encoders/qwen35.py
Normal file
833
comfy/text_encoders/qwen35.py
Normal file
@@ -0,0 +1,833 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import math
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.qwen_vl
|
||||
|
||||
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
||||
|
||||
|
||||
def _qwen35_layer_types(n):
|
||||
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
||||
|
||||
@dataclass
|
||||
class Qwen35Config:
|
||||
vocab_size: int = 248320
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 24
|
||||
# Full attention params
|
||||
num_attention_heads: int = 8
|
||||
num_key_value_heads: int = 2
|
||||
head_dim: int = 256
|
||||
partial_rotary_factor: float = 0.25
|
||||
# Linear attention (DeltaNet) params
|
||||
linear_num_key_heads: int = 16
|
||||
linear_num_value_heads: int = 16
|
||||
linear_key_head_dim: int = 128
|
||||
linear_value_head_dim: int = 128
|
||||
conv_kernel_size: int = 4
|
||||
# Shared params
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000000.0
|
||||
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
|
||||
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
|
||||
rms_norm_add: bool = True
|
||||
mlp_activation: str = "silu"
|
||||
qkv_bias: bool = False
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
|
||||
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
|
||||
transformer_type: str = "qwen35_2b"
|
||||
rope_dims: list = None
|
||||
rope_scale: float = None
|
||||
|
||||
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
|
||||
|
||||
QWEN35_MODELS = {
|
||||
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
|
||||
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
|
||||
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
|
||||
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||
}
|
||||
|
||||
|
||||
def _make_config(model_type, config_dict={}):
|
||||
overrides = QWEN35_MODELS.get(model_type, {}).copy()
|
||||
overrides.pop("vision", None)
|
||||
if "num_hidden_layers" in overrides:
|
||||
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
|
||||
overrides.update(config_dict)
|
||||
return Qwen35Config(**overrides)
|
||||
|
||||
|
||||
class RMSNormGated(RMSNorm):
|
||||
def forward(self, x, gate):
|
||||
return super().forward(x) * F.silu(gate.to(x.dtype))
|
||||
|
||||
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
|
||||
initial_dtype = query.dtype
|
||||
query = F.normalize(query, dim=-1)
|
||||
key = F.normalize(key, dim=-1)
|
||||
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
|
||||
|
||||
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size))
|
||||
key = F.pad(key, (0, 0, 0, pad_size))
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
total_sequence_length = sequence_length + pad_size
|
||||
scale = 1 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
last_recurrent_state = (
|
||||
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||
if initial_state is None
|
||||
else initial_state.to(value)
|
||||
)
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||
|
||||
for i in range(0, total_sequence_length // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
|
||||
|
||||
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
|
||||
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
|
||||
# weight: [channels, kernel_size]
|
||||
state_len = conv_state.shape[-1]
|
||||
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
|
||||
conv_state.copy_(combined[:, :, -state_len:])
|
||||
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
|
||||
if bias is not None:
|
||||
out = out + bias.unsqueeze(0).unsqueeze(-1)
|
||||
return F.silu(out).to(x.dtype)
|
||||
|
||||
|
||||
# GatedDeltaNet - Linear Attention Layer
|
||||
|
||||
class GatedDeltaNet(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
hidden = config.hidden_size
|
||||
self.num_key_heads = config.linear_num_key_heads
|
||||
self.num_value_heads = config.linear_num_value_heads
|
||||
self.key_head_dim = config.linear_key_head_dim
|
||||
self.value_head_dim = config.linear_value_head_dim
|
||||
self.conv_kernel_size = config.conv_kernel_size
|
||||
|
||||
key_dim = self.num_key_heads * self.key_head_dim
|
||||
value_dim = self.num_value_heads * self.value_head_dim
|
||||
self.key_dim = key_dim
|
||||
self.value_dim = value_dim
|
||||
conv_dim = key_dim * 2 + value_dim
|
||||
|
||||
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||
|
||||
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
|
||||
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
|
||||
|
||||
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, past_key_value=None, **kwargs):
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
use_recurrent = (
|
||||
past_key_value is not None
|
||||
and past_key_value[2] > 0
|
||||
and seq_len == 1
|
||||
)
|
||||
|
||||
# Projections (shared)
|
||||
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
|
||||
z = self.in_proj_z(x)
|
||||
b = self.in_proj_b(x)
|
||||
a = self.in_proj_a(x)
|
||||
|
||||
# Conv1d
|
||||
if use_recurrent:
|
||||
recurrent_state, conv_state, step_index = past_key_value
|
||||
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
|
||||
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
|
||||
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
|
||||
else:
|
||||
if past_key_value is not None:
|
||||
recurrent_state, conv_state, step_index = past_key_value
|
||||
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
||||
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
|
||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||
|
||||
# Split QKV and compute beta/g
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
|
||||
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
|
||||
beta = b.sigmoid()
|
||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
|
||||
|
||||
# Delta rule
|
||||
if use_recurrent:
|
||||
# single-token path: work in [B, heads, dim] without seq dim
|
||||
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
|
||||
|
||||
if self.num_value_heads != self.num_key_heads:
|
||||
rep = self.num_value_heads // self.num_key_heads
|
||||
query = query.repeat_interleave(rep, dim=1)
|
||||
key = key.repeat_interleave(rep, dim=1)
|
||||
|
||||
scale = self.key_head_dim ** -0.5
|
||||
q = F.normalize(query.float(), dim=-1) * scale
|
||||
k = F.normalize(key.float(), dim=-1)
|
||||
v = value.float()
|
||||
beta_t = beta.reshape(batch_size, -1)
|
||||
g_t = g.reshape(batch_size, -1).exp()
|
||||
|
||||
# In-place state update: [B, heads, k_dim, v_dim]
|
||||
recurrent_state.mul_(g_t[:, :, None, None])
|
||||
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
|
||||
delta = (v - kv_mem) * beta_t[:, :, None]
|
||||
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
|
||||
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
|
||||
|
||||
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
|
||||
present_key_value = (recurrent_state, conv_state, step_index + 1)
|
||||
else:
|
||||
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
|
||||
|
||||
if self.num_value_heads != self.num_key_heads:
|
||||
rep = self.num_value_heads // self.num_key_heads
|
||||
query = query.repeat_interleave(rep, dim=2)
|
||||
key = key.repeat_interleave(rep, dim=2)
|
||||
|
||||
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
|
||||
query, key, value, g=g, beta=beta,
|
||||
initial_state=None,
|
||||
output_final_state=past_key_value is not None,
|
||||
)
|
||||
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
if last_recurrent_state is not None:
|
||||
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
|
||||
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
|
||||
|
||||
# Gated norm + output projection (shared)
|
||||
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
|
||||
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
|
||||
return output, present_key_value
|
||||
|
||||
|
||||
# GatedAttention - Full Attention with output gating
|
||||
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
|
||||
"""Compute RoPE frequencies for partial rotary embeddings."""
|
||||
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
|
||||
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
|
||||
|
||||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
if mrope_section is not None and position_ids.shape[0] == 3:
|
||||
mrope_section_2 = [s * 2 for s in mrope_section]
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
|
||||
|
||||
|
||||
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
|
||||
"""Apply RoPE to only the first rotary_dim dimensions."""
|
||||
xq_rot = xq[..., :rotary_dim]
|
||||
xq_pass = xq[..., rotary_dim:]
|
||||
xk_rot = xk[..., :rotary_dim]
|
||||
xk_pass = xk[..., rotary_dim:]
|
||||
|
||||
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
|
||||
|
||||
xq = torch.cat([xq_rot, xq_pass], dim=-1)
|
||||
xk = torch.cat([xk_rot, xk_pass], dim=-1)
|
||||
return xq, xk
|
||||
|
||||
|
||||
class GatedAttention(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
|
||||
|
||||
# q_proj outputs 2x: query + gate
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
# QK norms with (1+weight) scaling
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||
batch_size, seq_length, _ = x.shape
|
||||
|
||||
# Project Q (with gate), K, V
|
||||
qg = self.q_proj(x)
|
||||
# Split into query and gate: each is [B, seq, inner_size]
|
||||
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
|
||||
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
|
||||
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
|
||||
|
||||
xk = self.k_proj(x)
|
||||
xv = self.v_proj(x)
|
||||
|
||||
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
|
||||
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
|
||||
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply partial RoPE
|
||||
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
|
||||
|
||||
# KV cache
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
past_key, past_value, index = past_key_value
|
||||
num_tokens = xk.shape[2]
|
||||
if past_key.shape[2] >= (index + num_tokens):
|
||||
past_key[:, :, index:index + num_tokens] = xk
|
||||
past_value[:, :, index:index + num_tokens] = xv
|
||||
xk = past_key[:, :, :index + num_tokens]
|
||||
xv = past_value[:, :, :index + num_tokens]
|
||||
present_key_value = (past_key, past_value, index + num_tokens)
|
||||
else:
|
||||
if index > 0:
|
||||
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_heads != self.num_kv_heads:
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
output = output * gate.sigmoid()
|
||||
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
|
||||
# Hybrid Transformer Block
|
||||
class Qwen35TransformerBlock(nn.Module):
|
||||
def __init__(self, config, index, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.layer_type = config.layer_types[index]
|
||||
if self.layer_type == "linear_attention":
|
||||
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
|
||||
else:
|
||||
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||
if self.layer_type == "linear_attention":
|
||||
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
|
||||
else:
|
||||
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
|
||||
|
||||
x = x + h
|
||||
x = x + self.mlp(self.post_attention_layernorm(x))
|
||||
return x, present_key_value
|
||||
|
||||
|
||||
# Qwen35 Transformer Backbone
|
||||
class Qwen35Transformer(Llama2_):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.normalize_in = False
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def get_past_len(self, past_key_values):
|
||||
for i, layer in enumerate(self.layers):
|
||||
if layer.layer_type == "full_attention":
|
||||
if len(past_key_values) > i:
|
||||
return past_key_values[i][2]
|
||||
break
|
||||
return 0
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
|
||||
return precompute_partial_rope(
|
||||
self.config.head_dim, rotary_dim, position_ids,
|
||||
self.config.rope_theta, device=device,
|
||||
mrope_section=self.config.mrope_section,
|
||||
)
|
||||
|
||||
|
||||
# Vision Encoder
|
||||
class Qwen35VisionPatchEmbed(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.patch_size = config["patch_size"]
|
||||
self.temporal_patch_size = config["temporal_patch_size"]
|
||||
self.in_channels = config["in_channels"]
|
||||
self.embed_dim = config["hidden_size"]
|
||||
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
|
||||
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
target_dtype = self.proj.weight.dtype
|
||||
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
|
||||
|
||||
|
||||
class Qwen35VisionMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
|
||||
|
||||
|
||||
class Qwen35VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, theta=10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, seqlen):
|
||||
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(seq, self.inv_freq)
|
||||
return freqs
|
||||
|
||||
|
||||
class Qwen35VisionAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||
seq_length = x.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
|
||||
|
||||
# Process per-sequence attention
|
||||
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
q_splits = torch.split(query_states, lengths, dim=0)
|
||||
k_splits = torch.split(key_states, lengths, dim=0)
|
||||
v_splits = torch.split(value_states, lengths, dim=0)
|
||||
|
||||
attn_outputs = []
|
||||
for q, k, v in zip(q_splits, k_splits, v_splits):
|
||||
q = q.transpose(0, 1).unsqueeze(0)
|
||||
k = k.transpose(0, 1).unsqueeze(0)
|
||||
v = v.transpose(0, 1).unsqueeze(0)
|
||||
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
|
||||
|
||||
attn_output = torch.cat(attn_outputs, dim=1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
return self.proj(attn_output)
|
||||
|
||||
|
||||
class Qwen35VisionBlock(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Qwen35VisionPatchMerger(nn.Module):
|
||||
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
|
||||
merge_dim = hidden_size * (spatial_merge_size ** 2)
|
||||
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
|
||||
self.merge_dim = merge_dim
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x).view(-1, self.merge_dim)
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
||||
|
||||
|
||||
class Qwen35VisionModel(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = config["spatial_merge_size"]
|
||||
self.patch_size = config["patch_size"]
|
||||
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
self.hidden_size = config["hidden_size"]
|
||||
self.num_heads = config["num_heads"]
|
||||
self.num_position_embeddings = config["num_position_embeddings"]
|
||||
|
||||
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
|
||||
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
|
||||
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
|
||||
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config["depth"])
|
||||
])
|
||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def rot_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||
freq_table = self.rotary_pos_emb(max_hw)
|
||||
device = freq_table.device
|
||||
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||
offset = 0
|
||||
for num_frames, height, width in grid_thw_list:
|
||||
num_frames, height, width = int(num_frames), int(height), int(width)
|
||||
merged_h, merged_w = height // merge_size, width // merge_size
|
||||
block_rows = torch.arange(merged_h, device=device)
|
||||
block_cols = torch.arange(merged_w, device=device)
|
||||
intra_row = torch.arange(merge_size, device=device)
|
||||
intra_col = torch.arange(merge_size, device=device)
|
||||
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
||||
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
||||
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||
if num_frames > 1:
|
||||
coords = coords.repeat(num_frames, 1)
|
||||
num_tokens = coords.shape[0]
|
||||
pos_ids[offset:offset + num_tokens] = coords
|
||||
offset += num_tokens
|
||||
embeddings = freq_table[pos_ids]
|
||||
embeddings = embeddings.flatten(1)
|
||||
return embeddings
|
||||
|
||||
def fast_pos_embed_interpolate(self, grid_thw):
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_ts = [int(row[0]) for row in grid_thw_list]
|
||||
grid_hs = [int(row[1]) for row in grid_thw_list]
|
||||
grid_ws = [int(row[2]) for row in grid_thw_list]
|
||||
device = self.pos_embed.weight.device
|
||||
idx_list = [[] for _ in range(4)]
|
||||
weight_list = [[] for _ in range(4)]
|
||||
for t, h, w in grid_thw_list:
|
||||
h, w = int(h), int(w)
|
||||
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
|
||||
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
|
||||
h_idxs_floor = h_idxs.int()
|
||||
w_idxs_floor = w_idxs.int()
|
||||
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
dh = h_idxs - h_idxs_floor
|
||||
dw = w_idxs - w_idxs_floor
|
||||
base_h = h_idxs_floor * self.num_grid_per_side
|
||||
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
|
||||
indices = [
|
||||
(base_h[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h[None].T + w_idxs_ceil[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
|
||||
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
|
||||
]
|
||||
weights = [
|
||||
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
|
||||
((1 - dh)[None].T * dw[None]).flatten(),
|
||||
(dh[None].T * (1 - dw)[None]).flatten(),
|
||||
(dh[None].T * dw[None]).flatten(),
|
||||
]
|
||||
for j in range(4):
|
||||
idx_list[j].extend(indices[j].tolist())
|
||||
weight_list[j].extend(weights[j].tolist())
|
||||
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
|
||||
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
|
||||
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
||||
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
|
||||
patch_pos_embeds_permute = []
|
||||
merge_size = self.spatial_merge_size
|
||||
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
|
||||
pos_embed = pos_embed.repeat(t, 1)
|
||||
pos_embed = (
|
||||
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
|
||||
.permute(0, 1, 3, 2, 4, 5)
|
||||
.flatten(0, 4)
|
||||
)
|
||||
patch_pos_embeds_permute.append(pos_embed)
|
||||
return torch.cat(patch_pos_embeds_permute)
|
||||
|
||||
def forward(self, x, grid_thw):
|
||||
x = self.patch_embed(x)
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
||||
x = x + pos_embeds
|
||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||
seq_len = x.shape[0]
|
||||
x = x.reshape(seq_len, -1)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos().unsqueeze(-2)
|
||||
sin = emb.sin().unsqueeze(-2)
|
||||
sin_half = sin.shape[-1] // 2
|
||||
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
merged = self.merger(x)
|
||||
return merged
|
||||
|
||||
# Model Wrapper
|
||||
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
model_type = "qwen35_2b"
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = _make_config(self.model_type, config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
|
||||
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
|
||||
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
|
||||
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
model_config = self.model.config
|
||||
past_key_values = []
|
||||
for i in range(model_config.num_hidden_layers):
|
||||
if model_config.layer_types[i] == "linear_attention":
|
||||
recurrent_state = torch.zeros(
|
||||
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
|
||||
device=device, dtype=torch.float32
|
||||
)
|
||||
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
|
||||
conv_state = torch.zeros(
|
||||
[batch, conv_dim, model_config.conv_kernel_size - 1],
|
||||
device=device, dtype=execution_dtype
|
||||
)
|
||||
past_key_values.append((recurrent_state, conv_state, 0))
|
||||
else:
|
||||
past_key_values.append((
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
0
|
||||
))
|
||||
return past_key_values
|
||||
|
||||
# Tokenizer and Text Encoder Wrappers
|
||||
|
||||
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
|
||||
from transformers import Qwen2Tokenizer
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
|
||||
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
|
||||
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image]
|
||||
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
if not thinking:
|
||||
llama_text += "<think>\n</think>\n"
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 248056: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen35ClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
|
||||
class Qwen35_(Qwen35):
|
||||
pass
|
||||
Qwen35_.model_type = model_type
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
|
||||
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Qwen35TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
|
||||
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def tokenizer(model_type="qwen35_2b"):
|
||||
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||
return Qwen35ImageTokenizer_
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
|
||||
class Qwen35TEModel_(Qwen35TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||
return Qwen35TEModel_
|
||||
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
File diff suppressed because one or more lines are too long
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
|
||||
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
# Potentially important for spatially precise edits. This is present in the HF implementation.
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
hidden_states = hidden_states[reverse_indices, :]
|
||||
return hidden_states
|
||||
|
||||
@@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
CurvePoint,
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -13,4 +17,8 @@ __all__ = [
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@@ -7,4 +8,8 @@ __all__ = [
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> CurveInput:
|
||||
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||
|
||||
Accepts:
|
||||
- A ``CurveInput`` instance (returned as-is).
|
||||
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||
"""
|
||||
if isinstance(data, CurveInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
raw_points = data["points"]
|
||||
interpolation = data.get("interpolation", "monotone_cubic")
|
||||
else:
|
||||
raw_points = data
|
||||
interpolation = "monotone_cubic"
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
if interpolation == "linear":
|
||||
return LinearCurve(points)
|
||||
if interpolation != "monotone_cubic":
|
||||
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||
return MonotoneCubicCurve(points)
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
from comfy_api.input import CurvePoint
|
||||
if TYPE_CHECKING:
|
||||
Type = CurveInput_
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
@@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.default is not None:
|
||||
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
Type = list[int]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
@@ -2240,5 +2253,6 @@ __all__ = [
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
widgets.upscale = "enabled" ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0457, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
$enabled := widgets.upscale = "enabled";
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$isFast
|
||||
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||
: $enabled ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$fmt := {"approximate": true, "note": "(base)"};
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
$enabled := widgets.upscale = "enabled";
|
||||
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||
$isFast
|
||||
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||
: $enabled ? (
|
||||
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.input import CurveInput
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CurveEditor(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CurveEditor",
|
||||
display_name="Curve Editor",
|
||||
category="utils",
|
||||
inputs=[
|
||||
io.Curve.Input("curve"),
|
||||
io.Histogram.Input("histogram", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Curve.Output("curve"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||
result = CurveInput.from_raw(curve)
|
||||
|
||||
ui = {}
|
||||
if histogram is not None:
|
||||
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||
|
||||
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||
|
||||
|
||||
class CurveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [CurveEditor]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return CurveExtension()
|
||||
@@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
|
||||
|
||||
|
||||
MAX_IMAGES = 5 # u_image0-4
|
||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
||||
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
||||
MAX_BOOLS = 10 # u_bool0-9
|
||||
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||
|
||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||
@@ -497,6 +499,8 @@ def _render_shader_batch(
|
||||
image_batches: list[list[np.ndarray]],
|
||||
floats: list[float],
|
||||
ints: list[int],
|
||||
bools: list[bool] | None = None,
|
||||
curves: list[np.ndarray] | None = None,
|
||||
) -> list[list[np.ndarray]]:
|
||||
"""
|
||||
Render a fragment shader for multiple batches efficiently.
|
||||
@@ -511,6 +515,8 @@ def _render_shader_batch(
|
||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||
floats: List of float uniforms
|
||||
ints: List of int uniforms
|
||||
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
||||
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
||||
|
||||
Returns:
|
||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||
@@ -533,11 +539,17 @@ def _render_shader_batch(
|
||||
# Detect multi-pass rendering
|
||||
num_passes = _detect_pass_count(fragment_code)
|
||||
|
||||
if bools is None:
|
||||
bools = []
|
||||
if curves is None:
|
||||
curves = []
|
||||
|
||||
# Track resources for cleanup
|
||||
program = None
|
||||
fbo = None
|
||||
output_textures = []
|
||||
input_textures = []
|
||||
curve_textures = []
|
||||
ping_pong_textures = []
|
||||
ping_pong_fbos = []
|
||||
|
||||
@@ -624,6 +636,28 @@ def _render_shader_batch(
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, v)
|
||||
|
||||
for i, v in enumerate(bools):
|
||||
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, 1 if v else 0)
|
||||
|
||||
# Create 1D LUT textures for curves (bound after image texture units)
|
||||
for i, lut in enumerate(curves):
|
||||
tex = gl.glGenTextures(1)
|
||||
curve_textures.append(tex)
|
||||
unit = MAX_IMAGES + i
|
||||
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
||||
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||
|
||||
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
||||
if loc >= 0:
|
||||
gl.glUniform1i(loc, unit)
|
||||
|
||||
# Get u_pass uniform location for multi-pass
|
||||
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
||||
|
||||
@@ -718,6 +752,8 @@ def _render_shader_batch(
|
||||
|
||||
for tex in input_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in curve_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in output_textures:
|
||||
gl.glDeleteTextures(int(tex))
|
||||
for tex in ping_pong_textures:
|
||||
@@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
|
||||
max=MAX_UNIFORMS,
|
||||
)
|
||||
|
||||
bool_template = io.Autogrow.TemplatePrefix(
|
||||
io.Boolean.Input("bool", default=False),
|
||||
prefix="u_bool",
|
||||
min=0,
|
||||
max=MAX_BOOLS,
|
||||
)
|
||||
|
||||
curve_template = io.Autogrow.TemplatePrefix(
|
||||
io.Curve.Input("curve"),
|
||||
prefix="u_curve",
|
||||
min=0,
|
||||
max=MAX_CURVES,
|
||||
)
|
||||
|
||||
return io.Schema(
|
||||
node_id="GLSLShader",
|
||||
display_name="GLSL Shader",
|
||||
@@ -762,6 +812,7 @@ class GLSLShader(io.ComfyNode):
|
||||
"Apply GLSL ES fragment shaders to images. "
|
||||
"u_resolution (vec2) is always available."
|
||||
),
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input(
|
||||
"fragment_shader",
|
||||
@@ -796,6 +847,8 @@ class GLSLShader(io.ComfyNode):
|
||||
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
||||
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
||||
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
||||
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
||||
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
||||
@@ -813,13 +866,19 @@ class GLSLShader(io.ComfyNode):
|
||||
images: io.Autogrow.Type,
|
||||
floats: io.Autogrow.Type = None,
|
||||
ints: io.Autogrow.Type = None,
|
||||
bools: io.Autogrow.Type = None,
|
||||
curves: io.Autogrow.Type = None,
|
||||
**kwargs,
|
||||
) -> io.NodeOutput:
|
||||
|
||||
image_list = [v for v in images.values() if v is not None]
|
||||
float_list = (
|
||||
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||
)
|
||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
||||
|
||||
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
||||
|
||||
if not image_list:
|
||||
raise ValueError("At least one input image is required")
|
||||
@@ -846,6 +905,8 @@ class GLSLShader(io.ComfyNode):
|
||||
image_batches,
|
||||
float_list,
|
||||
int_list,
|
||||
bool_list,
|
||||
curve_luts,
|
||||
)
|
||||
|
||||
# Collect outputs into tensors
|
||||
|
||||
@@ -3,6 +3,7 @@ import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LTXVReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVReferenceAudio",
|
||||
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||
category="conditioning/audio",
|
||||
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
|
||||
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
|
||||
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||
# Encode reference audio to latents and patchify
|
||||
audio_latents = audio_vae.encode(reference_audio)
|
||||
b, c, t, f = audio_latents.shape
|
||||
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||
ref_audio = {"tokens": ref_tokens}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
|
||||
|
||||
# Patch model with identity guidance
|
||||
m = model.clone()
|
||||
scale = identity_guidance_scale
|
||||
model_sampling = m.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def post_cfg_function(args):
|
||||
if scale == 0:
|
||||
return args["denoised"]
|
||||
|
||||
sigma = args["sigma"]
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ > sigma_start or sigma_ < sigma_end:
|
||||
return args["denoised"]
|
||||
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
model_options = args["model_options"].copy()
|
||||
x = args["input"]
|
||||
|
||||
# Strip ref_audio from conditioning for the no-reference pass
|
||||
noref_cond = []
|
||||
for entry in cond:
|
||||
new_entry = entry.copy()
|
||||
mc = new_entry.get("model_conds", {}).copy()
|
||||
mc.pop("ref_audio", None)
|
||||
new_entry["model_conds"] = mc
|
||||
noref_cond.append(new_entry)
|
||||
|
||||
(pred_noref,) = comfy.samplers.calc_cond_batch(
|
||||
args["model"], [noref_cond], x, sigma, model_options
|
||||
)
|
||||
|
||||
return cfg_result + (cond_pred - pred_noref) * scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return io.NodeOutput(m, positive, negative)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
LTXVReferenceAudio,
|
||||
]
|
||||
|
||||
|
||||
|
||||
92
comfy_extras/nodes_number_convert.py
Normal file
92
comfy_extras/nodes_number_convert.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Number Convert node for unified numeric type conversion.
|
||||
|
||||
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||
inputs into FLOAT and INT outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class NumberConvertNode(io.ComfyNode):
|
||||
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ComfyNumberConvert",
|
||||
display_name="Number Convert",
|
||||
category="math",
|
||||
search_aliases=[
|
||||
"int to float", "float to int", "number convert",
|
||||
"int2float", "float2int", "cast", "parse number",
|
||||
"string to number", "bool to int",
|
||||
],
|
||||
inputs=[
|
||||
io.MultiType.Input(
|
||||
"value",
|
||||
[io.Int, io.Float, io.String, io.Boolean],
|
||||
display_name="value",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Float.Output(display_name="FLOAT"),
|
||||
io.Int.Output(display_name="INT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
if isinstance(value, bool):
|
||||
float_val = 1.0 if value else 0.0
|
||||
int_val = 1 if value else 0
|
||||
elif isinstance(value, int):
|
||||
float_val = float(value)
|
||||
int_val = value
|
||||
elif isinstance(value, float):
|
||||
float_val = value
|
||||
int_val = int(value)
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise ValueError("Cannot convert empty string to number.")
|
||||
try:
|
||||
float_val = float(text)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Cannot convert string to number: {value!r}"
|
||||
) from None
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
try:
|
||||
int_val = int(text)
|
||||
except ValueError:
|
||||
int_val = int(float_val)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported input type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
|
||||
return io.NodeOutput(float_val, int_val)
|
||||
|
||||
|
||||
class NumberConvertExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [NumberConvertNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||
return NumberConvertExtension()
|
||||
@@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
|
||||
def g(cls, x):
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
|
||||
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return g / g.sum()
|
||||
return (g / g.sum()).to(dtype)
|
||||
|
||||
class Blur(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = blur_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||
@@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
|
||||
image = image.to(comfy.model_management.get_torch_device())
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
|
||||
kernel = kernel.to(dtype=image.dtype)
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
|
||||
@@ -15,6 +15,7 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
||||
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
||||
io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
|
||||
]
|
||||
),
|
||||
io.DynamicCombo.Option(
|
||||
@@ -25,7 +26,7 @@ class TextGenerate(io.ComfyNode):
|
||||
|
||||
return io.Schema(
|
||||
node_id="TextGenerate",
|
||||
category="textgen/",
|
||||
category="textgen",
|
||||
search_aliases=["LLM", "gemma"],
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
@@ -33,6 +34,7 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(display_name="generated_text"),
|
||||
@@ -40,9 +42,9 @@ class TextGenerate(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
@@ -52,6 +54,7 @@ class TextGenerate(io.ComfyNode):
|
||||
min_p = sampling_mode.get("min_p", 0.0)
|
||||
seed = sampling_mode.get("seed", None)
|
||||
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
||||
presence_penalty = sampling_mode.get("presence_penalty", 0.0)
|
||||
|
||||
generated_ids = clip.generate(
|
||||
tokens,
|
||||
@@ -62,6 +65,7 @@ class TextGenerate(io.ComfyNode):
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
@@ -156,12 +160,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
|
||||
@@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for lora.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"quantized_backward",
|
||||
default=False,
|
||||
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"algorithm",
|
||||
options=list(adapter_maps.keys()),
|
||||
@@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
quantized_backward,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
@@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed = seed[0]
|
||||
training_dtype = training_dtype[0]
|
||||
lora_dtype = lora_dtype[0]
|
||||
quantized_backward = quantized_backward[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
@@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
|
||||
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
latents = _process_latents_bucket_mode(latents)
|
||||
@@ -1137,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
@@ -1145,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
use_grad_scaler = True
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
@@ -1156,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.18.2"
|
||||
__version__ = "0.18.1"
|
||||
|
||||
46
main.py
46
main.py
@@ -9,6 +9,8 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
@@ -192,7 +194,6 @@ if 'torch' in sys.modules:
|
||||
|
||||
|
||||
import comfy.utils
|
||||
from app.assets.seeder import asset_seeder
|
||||
|
||||
import execution
|
||||
import server
|
||||
@@ -240,6 +241,38 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
"""Extract absolute file paths for output items from a history result."""
|
||||
paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for node_output in history_result.get("outputs", {}).values():
|
||||
for items in node_output.values():
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type not in ("output", "temp"):
|
||||
continue
|
||||
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||
if base_dir is None:
|
||||
continue
|
||||
base_dir = os.path.abspath(base_dir)
|
||||
filename = item.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
abs_path = os.path.abspath(
|
||||
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||
)
|
||||
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||
continue
|
||||
if abs_path not in seen:
|
||||
seen.add(abs_path)
|
||||
paths.append(abs_path)
|
||||
return paths
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@@ -274,6 +307,7 @@ def prompt_worker(q, server_instance):
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
@@ -296,6 +330,10 @@ def prompt_worker(q, server_instance):
|
||||
else:
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
paths = _collect_output_absolute_paths(e.history_result)
|
||||
register_output_files(paths, job_id=prompt_id)
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
@@ -317,6 +355,9 @@ def prompt_worker(q, server_instance):
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
@@ -471,6 +512,9 @@ if __name__ == "__main__":
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
|
||||
if args.disable_dynamic_vram:
|
||||
logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
|
||||
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
try:
|
||||
x = start_all_func()
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b6
|
||||
comfyui_manager==4.1
|
||||
|
||||
2
nodes.py
2
nodes.py
@@ -2454,7 +2454,9 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.18.2"
|
||||
version = "0.18.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.41.21
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.38
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Base
|
||||
@@ -23,6 +23,21 @@ def db_engine():
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine_fk():
|
||||
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def _set_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(db_engine):
|
||||
"""Session fixture for tests that need direct DB access."""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.services.file_utils import get_mtime_ns
|
||||
from app.assets.scanner import (
|
||||
ENRICHMENT_HASHED,
|
||||
ENRICHMENT_METADATA,
|
||||
@@ -20,6 +22,13 @@ def _create_stub_asset(
|
||||
name: str | None = None,
|
||||
) -> tuple[Asset, AssetReference]:
|
||||
"""Create a stub asset with reference for testing enrichment."""
|
||||
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||
try:
|
||||
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||
mtime_ns = get_mtime_ns(stat_result)
|
||||
except OSError:
|
||||
mtime_ns = 1234567890000000000
|
||||
|
||||
asset = Asset(
|
||||
id=asset_id,
|
||||
hash=None,
|
||||
@@ -35,7 +44,7 @@ def _create_stub_asset(
|
||||
name=name or f"test-asset-{asset_id}",
|
||||
owner_id="system",
|
||||
file_path=file_path,
|
||||
mtime_ns=1234567890000000000,
|
||||
mtime_ns=mtime_ns,
|
||||
enrichment_level=ENRICHMENT_STUB,
|
||||
)
|
||||
session.add(ref)
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""Tests for ingest services."""
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session as SASession, Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, Tag
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||
from app.assets.database.queries import get_reference_tags
|
||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
||||
from app.assets.services.ingest import (
|
||||
_ingest_file_from_path,
|
||||
_register_existing_asset,
|
||||
ingest_existing_file,
|
||||
)
|
||||
|
||||
|
||||
class TestIngestFileFromPath:
|
||||
@@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
|
||||
|
||||
assert result.created is True
|
||||
assert set(result.tags) == {"alpha", "beta"}
|
||||
|
||||
|
||||
class TestIngestExistingFileTagFK:
|
||||
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||
|
||||
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||
before they can be referenced in asset_reference_tags."""
|
||||
|
||||
@contextmanager
|
||||
def _create_session():
|
||||
with SASession(db_engine_fk) as sess:
|
||||
yield sess
|
||||
|
||||
file_path = temp_dir / "output.png"
|
||||
file_path.write_bytes(b"image data")
|
||||
|
||||
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||
patch(
|
||||
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||
return_value=("output.png", ["output"]),
|
||||
):
|
||||
result = ingest_existing_file(
|
||||
abs_path=str(file_path),
|
||||
extra_tags=["my-job"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
with SASession(db_engine_fk) as sess:
|
||||
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||
assert "output" in tag_names
|
||||
assert "my-job" in tag_names
|
||||
|
||||
ref_tags = sess.query(AssetReferenceTag).all()
|
||||
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||
assert "output" in ref_tag_names
|
||||
assert "my-job" in ref_tag_names
|
||||
|
||||
81
tests-unit/assets_test/services/test_path_utils.py
Normal file
81
tests-unit/assets_test/services/test_path_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Tests for path_utils – asset category resolution."""
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_dirs():
|
||||
"""Create temporary input, output, and temp directories."""
|
||||
with tempfile.TemporaryDirectory() as root:
|
||||
root_path = Path(root)
|
||||
input_dir = root_path / "input"
|
||||
output_dir = root_path / "output"
|
||||
temp_dir = root_path / "temp"
|
||||
models_dir = root_path / "models" / "checkpoints"
|
||||
for d in (input_dir, output_dir, temp_dir, models_dir):
|
||||
d.mkdir(parents=True)
|
||||
|
||||
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||
|
||||
with patch(
|
||||
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||
return_value=[("checkpoints", [str(models_dir)])],
|
||||
):
|
||||
yield {
|
||||
"input": input_dir,
|
||||
"output": output_dir,
|
||||
"temp": temp_dir,
|
||||
"models": models_dir,
|
||||
}
|
||||
|
||||
|
||||
class TestGetAssetCategoryAndRelativePath:
|
||||
def test_input_file(self, fake_dirs):
|
||||
f = fake_dirs["input"] / "photo.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "input"
|
||||
assert rel == "photo.png"
|
||||
|
||||
def test_output_file(self, fake_dirs):
|
||||
f = fake_dirs["output"] / "result.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "output"
|
||||
assert rel == "result.png"
|
||||
|
||||
def test_temp_file(self, fake_dirs):
|
||||
"""Regression: temp files must be categorised, not raise ValueError."""
|
||||
f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "temp"
|
||||
assert rel == "GLSLShader_output_00004_.png"
|
||||
|
||||
def test_temp_file_in_subfolder(self, fake_dirs):
|
||||
sub = fake_dirs["temp"] / "sub"
|
||||
sub.mkdir()
|
||||
f = sub / "ComfyUI_temp_tczip_00004_.png"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "temp"
|
||||
assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")
|
||||
|
||||
def test_model_file(self, fake_dirs):
|
||||
f = fake_dirs["models"] / "model.safetensors"
|
||||
f.touch()
|
||||
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||
assert cat == "models"
|
||||
|
||||
def test_unknown_path_raises(self, fake_dirs):
|
||||
with pytest.raises(ValueError, match="not within"):
|
||||
get_asset_category_and_relative_path("/some/random/path.png")
|
||||
180
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
180
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||
|
||||
|
||||
class TestNumberConvertExecute:
|
||||
@staticmethod
|
||||
def _exec(value) -> object:
|
||||
return NumberConvertNode.execute(value)
|
||||
|
||||
# --- INT input ---
|
||||
|
||||
def test_int_input(self):
|
||||
result = self._exec(42)
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_int_zero(self):
|
||||
result = self._exec(0)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
def test_int_negative(self):
|
||||
result = self._exec(-7)
|
||||
assert result[0] == -7.0
|
||||
assert result[1] == -7
|
||||
|
||||
# --- FLOAT input ---
|
||||
|
||||
def test_float_input(self):
|
||||
result = self._exec(3.14)
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_float_truncation_toward_zero(self):
|
||||
result = self._exec(-2.9)
|
||||
assert result[0] == -2.9
|
||||
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||
|
||||
def test_float_output_type(self):
|
||||
result = self._exec(5)
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
def test_int_output_type(self):
|
||||
result = self._exec(5.7)
|
||||
assert isinstance(result[1], int)
|
||||
|
||||
# --- BOOL input ---
|
||||
|
||||
def test_bool_true(self):
|
||||
result = self._exec(True)
|
||||
assert result[0] == 1.0
|
||||
assert result[1] == 1
|
||||
|
||||
def test_bool_false(self):
|
||||
result = self._exec(False)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
# --- STRING input ---
|
||||
|
||||
def test_string_integer(self):
|
||||
result = self._exec("42")
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_string_float(self):
|
||||
result = self._exec("3.14")
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative(self):
|
||||
result = self._exec("-5.5")
|
||||
assert result[0] == -5.5
|
||||
assert result[1] == -5
|
||||
|
||||
def test_string_with_whitespace(self):
|
||||
result = self._exec(" 7.0 ")
|
||||
assert result[0] == 7.0
|
||||
assert result[1] == 7
|
||||
|
||||
def test_string_scientific_notation(self):
|
||||
result = self._exec("1e3")
|
||||
assert result[0] == 1000.0
|
||||
assert result[1] == 1000
|
||||
|
||||
# --- Large number precision (string input) ---
|
||||
|
||||
def test_string_large_int_above_2_53(self):
|
||||
"""Text-to-int must not lose precision for integers beyond 2^53."""
|
||||
big = 2**53 + 1 # 9007199254740993
|
||||
result = self._exec(str(big))
|
||||
assert result[1] == big
|
||||
|
||||
def test_string_large_negative_int_above_2_53(self):
|
||||
big = -(2**53 + 1)
|
||||
result = self._exec(str(big))
|
||||
assert result[1] == big
|
||||
|
||||
def test_string_very_large_int(self):
|
||||
big = 2**63 + 42
|
||||
result = self._exec(str(big))
|
||||
assert result[1] == big
|
||||
|
||||
def test_string_large_int_float_output_is_float(self):
|
||||
"""FLOAT output is still a float (may lose precision, but must be float type)."""
|
||||
result = self._exec(str(2**53 + 1))
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
# --- Large number precision (int input) ---
|
||||
|
||||
def test_int_large_above_2_53(self):
|
||||
"""Native int input must preserve its value in the INT output."""
|
||||
big = 2**53 + 1
|
||||
result = self._exec(big)
|
||||
assert result[1] == big
|
||||
|
||||
def test_int_large_negative_above_2_53(self):
|
||||
big = -(2**53 + 1)
|
||||
result = self._exec(big)
|
||||
assert result[1] == big
|
||||
|
||||
def test_int_very_large(self):
|
||||
big = 2**100
|
||||
result = self._exec(big)
|
||||
assert result[1] == big
|
||||
|
||||
# --- String decimal / scientific notation fallback ---
|
||||
|
||||
def test_string_decimal_still_truncates(self):
|
||||
"""Strings with decimal points fall back to int(float(...)) truncation."""
|
||||
result = self._exec("3.7")
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative_decimal_truncates(self):
|
||||
result = self._exec("-2.9")
|
||||
assert result[1] == -2
|
||||
|
||||
def test_string_scientific_large(self):
|
||||
result = self._exec("1e18")
|
||||
assert result[0] == 1e18
|
||||
assert result[1] == 10**18
|
||||
|
||||
# --- STRING error paths ---
|
||||
|
||||
def test_empty_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec("")
|
||||
|
||||
def test_whitespace_only_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec(" ")
|
||||
|
||||
def test_non_numeric_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||
self._exec("abc")
|
||||
|
||||
def test_string_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("inf")
|
||||
|
||||
def test_string_nan_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("nan")
|
||||
|
||||
def test_string_negative_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("-inf")
|
||||
|
||||
# --- Unsupported type ---
|
||||
|
||||
def test_unsupported_type_raises(self):
|
||||
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||
self._exec([1, 2, 3])
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -771,6 +772,188 @@ class TestSeederStopRestart:
|
||||
assert collected_roots[1] == ("input",)
|
||||
|
||||
|
||||
class TestEnqueueEnrichHandoff:
|
||||
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||
|
||||
def test_pending_enrich_runs_after_scan_completes(
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
"""A queued enrich request runs automatically when a scan finishes."""
|
||||
enrich_roots_seen: list[tuple] = []
|
||||
original_start = fresh_seeder.start
|
||||
|
||||
def tracking_start(*args, **kwargs):
|
||||
phase = kwargs.get("phase")
|
||||
roots = kwargs.get("roots", args[0] if args else None)
|
||||
result = original_start(*args, **kwargs)
|
||||
if phase == ScanPhase.ENRICH and result:
|
||||
enrich_roots_seen.append(roots)
|
||||
return result
|
||||
|
||||
fresh_seeder.start = tracking_start
|
||||
|
||||
# Start a fast scan, then enqueue an enrich while it's running
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
reached.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert reached.wait(timeout=2.0)
|
||||
|
||||
queued = fresh_seeder.enqueue_enrich(
|
||||
roots=("input",), compute_hashes=True
|
||||
)
|
||||
assert queued is False # queued, not started immediately
|
||||
|
||||
barrier.set()
|
||||
|
||||
# Wait for the original scan + the auto-started enrich scan
|
||||
deadline = time.monotonic() + 5.0
|
||||
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||
time.sleep(0.05)
|
||||
|
||||
assert enrich_roots_seen == [("input",)]
|
||||
|
||||
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||
|
||||
Simulates the race: another thread calls enqueue_enrich right as the
|
||||
scan thread is draining _pending_enrich. The enqueue must either be
|
||||
picked up by the draining scan or successfully start its own scan.
|
||||
"""
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
enrich_started = threading.Event()
|
||||
|
||||
enrich_call_count = 0
|
||||
|
||||
def slow_collect(*args):
|
||||
reached.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
# Track how many times start_enrich actually fires
|
||||
real_start_enrich = fresh_seeder.start_enrich
|
||||
enrich_roots_seen: list[tuple] = []
|
||||
|
||||
def tracking_start_enrich(**kwargs):
|
||||
nonlocal enrich_call_count
|
||||
enrich_call_count += 1
|
||||
enrich_roots_seen.append(kwargs.get("roots"))
|
||||
result = real_start_enrich(**kwargs)
|
||||
if result:
|
||||
enrich_started.set()
|
||||
return result
|
||||
|
||||
fresh_seeder.start_enrich = tracking_start_enrich
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
# Start a scan
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert reached.wait(timeout=2.0)
|
||||
|
||||
# Queue an enrich while scan is running
|
||||
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
# Let scan finish — drain will fire start_enrich atomically
|
||||
barrier.set()
|
||||
|
||||
# Wait for drain to complete and the enrich scan to start
|
||||
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||
assert ("output",) in enrich_roots_seen
|
||||
|
||||
def test_concurrent_enqueue_during_drain_not_lost(
|
||||
self, fresh_seeder: _AssetSeeder,
|
||||
):
|
||||
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||
|
||||
Because the drain now holds _lock through the start_enrich call,
|
||||
a concurrent enqueue_enrich will block until start_enrich has
|
||||
transitioned state to RUNNING, then the enqueue will queue its
|
||||
payload as _pending_enrich for the *next* drain.
|
||||
"""
|
||||
scan_barrier = threading.Event()
|
||||
scan_reached = threading.Event()
|
||||
enrich_barrier = threading.Event()
|
||||
enrich_reached = threading.Event()
|
||||
|
||||
collect_call = 0
|
||||
|
||||
def gated_collect(*args):
|
||||
nonlocal collect_call
|
||||
collect_call += 1
|
||||
if collect_call == 1:
|
||||
# First call: the initial fast scan
|
||||
scan_reached.set()
|
||||
scan_barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
enrich_call = 0
|
||||
|
||||
def gated_get_unenriched(*args, **kwargs):
|
||||
nonlocal enrich_call
|
||||
enrich_call += 1
|
||||
if enrich_call == 1:
|
||||
# First enrich batch: signal and block
|
||||
enrich_reached.set()
|
||||
enrich_barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
):
|
||||
# 1. Start fast scan
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert scan_reached.wait(timeout=2.0)
|
||||
|
||||
# 2. Queue enrich while fast scan is running
|
||||
queued = fresh_seeder.enqueue_enrich(
|
||||
roots=("input",), compute_hashes=False
|
||||
)
|
||||
assert queued is False
|
||||
|
||||
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||
scan_barrier.set()
|
||||
|
||||
# 4. Wait until the drained enrich scan is running
|
||||
assert enrich_reached.wait(timeout=5.0)
|
||||
|
||||
# 5. Now enqueue another enrich while the drained scan is running
|
||||
queued2 = fresh_seeder.enqueue_enrich(
|
||||
roots=("output",), compute_hashes=True
|
||||
)
|
||||
assert queued2 is False # should be queued, not started
|
||||
|
||||
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||
with fresh_seeder._lock:
|
||||
assert fresh_seeder._pending_enrich is not None
|
||||
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||
|
||||
# Let the enrich scan finish
|
||||
enrich_barrier.set()
|
||||
|
||||
deadline = time.monotonic() + 5.0
|
||||
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||
return UnenrichedReferenceRow(
|
||||
reference_id=ref_id, asset_id=asset_id,
|
||||
|
||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def seeder():
|
||||
"""Fresh seeder instance for each test."""
|
||||
return _AssetSeeder()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reset_to_idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetToIdle:
|
||||
def test_sets_idle_and_clears_progress(self, seeder):
|
||||
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||
seeder._state = State.RUNNING
|
||||
seeder._progress = progress
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is progress
|
||||
|
||||
def test_noop_when_progress_already_none(self, seeder):
|
||||
"""_reset_to_idle should handle None progress gracefully."""
|
||||
seeder._state = State.CANCELLING
|
||||
seeder._progress = None
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – immediate start when idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichStartsImmediately:
|
||||
def test_starts_when_idle(self, seeder):
|
||||
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||
|
||||
def test_no_pending_when_started_immediately(self, seeder):
|
||||
"""No pending request should be stored when start_enrich succeeds."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True):
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – queuing when busy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichQueuesWhenBusy:
|
||||
def test_queues_when_busy(self, seeder):
|
||||
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
|
||||
assert result is False
|
||||
assert seeder._pending_enrich == {
|
||||
"roots": ("models",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – merging when a pending request already exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichMergesPending:
|
||||
def _make_busy(self, seeder):
|
||||
"""Patch start_enrich to always return False (seeder busy)."""
|
||||
return patch.object(seeder, "start_enrich", return_value=False)
|
||||
|
||||
def test_merges_roots(self, seeder):
|
||||
"""A second enqueue should merge roots with the existing pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",))
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "output"}
|
||||
|
||||
def test_merges_overlapping_roots(self, seeder):
|
||||
"""Duplicate roots should be deduplicated."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models", "input"))
|
||||
seeder.enqueue_enrich(roots=("input", "output"))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
|
||||
def test_compute_hashes_sticky_true(self, seeder):
|
||||
"""Once compute_hashes is True it should stay True after merging."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_stays_false(self, seeder):
|
||||
"""If both enqueues have compute_hashes=False it stays False."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is False
|
||||
|
||||
def test_triple_merge(self, seeder):
|
||||
"""Three successive enqueues should all merge correctly."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending enrich drains after scan completes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPendingEnrichDrain:
|
||||
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||
"""After a fast scan finishes, the pending enrich should be started."""
|
||||
seeder = _AssetSeeder()
|
||||
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": True,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_called_once_with(
|
||||
roots=("output",),
|
||||
compute_hashes=True,
|
||||
)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||
seeder = _AssetSeeder()
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_no_drain_when_no_pending(self, *_mocks):
|
||||
"""start_enrich should not be called when there is no pending request."""
|
||||
seeder = _AssetSeeder()
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-safety of enqueue_enrich
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichThreadSafety:
|
||||
def test_concurrent_enqueues(self, seeder):
|
||||
"""Multiple threads enqueuing should not lose roots."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
barrier = threading.Barrier(3)
|
||||
|
||||
def enqueue(root):
|
||||
barrier.wait()
|
||||
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=enqueue, args=(r,))
|
||||
for r in ("models", "input", "output")
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
@@ -24,6 +24,7 @@ def init_mime_types():
|
||||
# Web types (used by server.py for static file serving)
|
||||
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
||||
mimetypes.add_type('image/webp', '.webp')
|
||||
mimetypes.add_type('image/svg+xml', '.svg')
|
||||
|
||||
# Model and data file types (used by asset scanning / metadata extraction)
|
||||
mimetypes.add_type("application/safetensors", ".safetensors")
|
||||
|
||||
Reference in New Issue
Block a user