Compare commits

..

27 Commits

Author SHA1 Message Date
Luke Mino-Altherr
ad5604fb0b Remove unused enable_safetensors parameter from extract_file_metadata
Amp-Thread-ID: https://ampcode.com/threads/T-019ccb0b-2980-74fc-b62f-5fce0f658d8e
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 17:29:22 -08:00
Luke Mino-Altherr
7f00f48c96 Consolidate hash functions into single implementation
Extract file open/seek/restore logic into _open_for_hashing context
manager and use a single hash loop in compute_blake3_hash for both
file paths and file objects.

Amp-Thread-ID: https://ampcode.com/threads/T-019ccb05-0db1-7206-8bd9-1c2efb898fef
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 17:29:12 -08:00
Luke Mino-Altherr
42edf71854 change logging.error to logging.exception 2026-03-07 16:37:13 -08:00
Luke Mino-Altherr
9c2a423aec Defer asset_seeder resume until GC interval elapses
Only resume the asset scanner when the needs_gc time condition is
satisfied, preventing the scanner from restarting between rapid
successive prompt executions.

Amp-Thread-ID: https://ampcode.com/threads/T-019cc637-2352-7139-b753-47c19f43b55c
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
731a95eb13 Fix five code review issues
1. Seeder pause/resume: only resume after prompt execution if pause()
   returned True, preventing undo of user-initiated pauses.

2. Missing rollback in enrich_assets_batch: add sess.rollback() in
   exception handler to prevent broken session state for subsequent
   batch operations.

3. Hash checkpoint validation: store mtime_ns/file_size in
   HashCheckpoint and re-stat on resume instead of comparing the same
   stat result to itself.

4. Scan progress preserved: save _last_progress before clearing
   _progress in finally blocks so wait=true endpoint returns final
   stats instead of zeros.

5. Download XSS hardening: block dangerous MIME types (matching
   server.py) and add X-Content-Type-Options: nosniff header to
   asset content endpoint.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb6b-e97b-776d-8c43-2de8acd0d09e
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
781d451355 fix: support in-memory SQLite and ensure custom MIME types are initialized
- Add _init_memory_db() path using Base.metadata.create_all + StaticPool
  since Alembic migrations don't work with in-memory SQLite (each
  connection gets its own separate database)
- Call init_mime_types() at module load in metadata_extract so custom
  types like application/safetensors are always registered

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb5f-13d1-7429-8cfd-815625c4d032
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
1f1894608d Centralize MIME type initialization into utils/mime_types.py
Move mimetypes.init() and all custom type registrations from server.py
and metadata_extract.py into a single init_mime_types() function called
once at startup in main.py.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb2a-513a-7458-9962-b4100e4f124d
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
657bf5a55e Add checkpoint/interrupt support to BLAKE3 hashing
- Add HashCheckpoint dataclass for saving/resuming interrupted hash computations
- compute_blake3_hash now accepts interrupt_check and checkpoint parameters
- Returns (digest, None) on completion or (None, checkpoint) on interruption
- Update ingest.py caller to handle new tuple return type

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb0b-8563-7199-b628-33e3c4fe9f41
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
14183b3c21 Optimize enrichment: shared DB session per batch, add fast scan timing logs
- Add debug timing logs for each fast scan sub-step (sync_root, collect_paths, build_asset_specs) and info-level total timing
- Refactor enrich_asset to accept a session parameter instead of creating one per file
- enrich_assets_batch now opens one session for the entire batch, committing after each asset to keep transactions short
- Simplify enrichment tests by removing create_session mocking

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb0b-8563-7199-b628-33e3c4fe9f41
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
373b2a735e Replace SQLite exclusive lock with cross-platform file lock
- Use filelock (FileLock) instead of PRAGMA locking_mode=EXCLUSIVE to
  prevent multi-process database access. The OS automatically releases
  the lock when the process exits, even on crashes or Ctrl+C.
- Add friendly error messages for database-is-locked and general
  database init failures when --enable-assets is set.
- Exit the process instead of silently disabling assets when the user
  explicitly passed --enable-assets and the database fails.
- Add filelock to requirements.txt.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbab8-50d4-748c-9669-2506575dda44
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
a130ccc942 feat: soft-delete for AssetReference with scanner persistence
- Add deleted_at column to AssetReference model and migration
- soft_delete_reference_by_id sets deleted_at instead of removing rows
- DELETE /api/assets/{id} defaults to soft-delete; delete_content=true
  for hard-delete
- Add deleted_at IS NULL filters to read queries, tag queries, and
  scanner queries so soft-deleted refs are invisible
- restore_references_by_paths skips soft-deleted refs
- upsert_reference clears deleted_at on explicit re-ingest
- Add tests for soft-delete API behavior, scanner persistence, bulk
  insert, enrichment exclusion, and seed asset garbage collection

Amp-Thread-ID: https://ampcode.com/threads/T-019cb6fc-c05c-761f-b855-6d5d1c9defa2
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
a8371ef1bc fix: acquire exclusive DB lock after migrations to avoid self-deadlock
The previous commit acquired the exclusive lock before Alembic migrations,
but Alembic opens its own connection — which was then blocked by our lock.
Move lock acquisition to after migrations complete in a dedicated connection.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
d653b86bd7 Fix asset API security and correctness issues
- Content-Disposition: drop raw filename= parameter, use only RFC 5987
  filename*=UTF-8'' to prevent header injection via ; and special chars
- delete_asset: default delete_content to False (non-destructive) when
  query parameter is omitted
- create_asset_from_hash: return 400 MISSING_INPUT instead of 404 when
  hash not found and no file uploaded (client input error, not missing resource)
- seeder: clear _progress when returning to IDLE so get_status() does not
  return stale progress after scan completion
- hashing: handle non-seekable streams in _hash_file_obj by checking
  seekable() before attempting tell/seek
- bulk_ingest: filter lost_paths to only include paths tied to actually
  inserted asset IDs, preventing inflated counts from ON CONFLICT drops

Amp-Thread-ID: https://ampcode.com/threads/T-019cb67a-9822-7438-ab05-d09991a9f7f3
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
f26384f371 fix: acquire exclusive SQLite lock at startup to prevent multi-process DB conflicts
When two ComfyUI processes share the same database file but point to
different input/output/model directories, each process's scan marks
the other's assets as missing, causing unreliable asset visibility.
This adds an exclusive lock so the second process fails fast at startup
with a clear message to use --database-url.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
bfdb78da05 Reduce duplication across assets module
- Extract validate_blake3_hash() into helpers.py, used by upload, schemas, routes
- Extract get_reference_with_owner_check() into queries, used by 4 service functions
- Extract build_prefix_like_conditions() into queries/common.py, used by 3 queries
- Replace 3 inlined tag queries with get_reference_tags() calls
- Consolidate AddTagsDict/RemoveTagsDict TypedDicts into AddTagsResult/RemoveTagsResult
  dataclasses, eliminating manual field copying in tagging.py
- Make iter_row_chunks delegate to iter_chunks
- Inline trivial compute_filename_for_reference wrapper (unused session param)
- Remove mark_assets_missing_outside_prefixes pass-through in bulk_ingest.py
- Clean up unused imports (os, time, dependencies_available)
- Disable assets routes on DB init failure in main.py

Amp-Thread-ID: https://ampcode.com/threads/T-019cb649-dd4e-71ff-9a0e-ae517365207b
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
e59fbc101d fix: address code review feedback - round 2
- Reject path separators (/, \, os.sep) in tag components for defense-in-depth
- Add comment explaining double-relpath normalization trick
- Add _require_assets_feature_enabled decorator returning 503 when disabled
- Call asset_seeder.disable() when --enable-assets is not passed
- Add iter_chunks to bulk_update_needs_verify, bulk_update_is_missing,
  and delete_references_by_ids to respect SQLite bind param limits
- Fix CacheStateRow.size_bytes NULL coercion (0 -> None) to avoid
  false needs_verify flags on assets with unknown size
- Add PermissionError catch in delete_asset_tags route (403 vs 500)
- Add hash-is-None guard in delete_orphaned_seed_asset
- Validate from_asset_id in reassign_asset_references
- Initialize _prune_first in __init__, remove getattr workaround
- Cap error accumulation in _add_error to 200
- Remove confirmed dead code: seed_assets, compute_filename_for_asset,
  ALLOWED_ROOTS, AssetNotFoundError, SetTagsResult, update_enrichment_level,
  Asset.to_dict, AssetReference.to_dict, _AssetSeeder.enable

Amp-Thread-ID: https://ampcode.com/threads/T-019cb610-1b55-74b6-8dbb-381d73c387c0
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
defd97d8b8 fix: address code review feedback
- Fix missing import for compute_filename_for_reference in ingest.py
- Apply code review fixes across routes, queries, scanner, seeder,
  hashing, ingest, path_utils, main, and server
- Update and add tests for sync references and seeder

Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
a611444b82 feat: add --enable-assets flag, disable assets by default, expose to frontend
Replace --disable-assets-autoscan with --enable-assets so the assets
system (API routes, database sync, background scanning) is off by
default and must be explicitly opted into. Expose the flag as an
"assets" entry in SERVER_FEATURE_FLAGS so the frontend can read it
from GET /features.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
7a54eb33ca fix: update test to expect ScanInProgressError when marking missing during active scan
Amp-Thread-ID: https://ampcode.com/threads/T-019c92af-47c7-7448-b111-4ebfbf5585e6
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
c3cc3ba24f Exclude hidden files and custom_nodes folder from asset scanning
- Filter hidden files/directories (dot-prefixed) in collect_models_files()
  using is_visible(), matching the existing behavior for input/output roots
- Exclude the 'custom_nodes' folder name from get_comfy_models_folders();
  custom nodes that register their own paths under other folder names
  will still be scanned as expected

Amp-Thread-ID: https://ampcode.com/threads/T-019c924b-591a-725e-b8b7-0d49ba1a5591
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
09730315d2 fix: replace os.path.commonpath with Path.is_relative_to for cross-drive safety
commonpath raises ValueError on Windows when comparing paths on different
drives (e.g. C:\models vs D:\extra_models). Replace all usages in the
asset scanner with Path.is_relative_to() which handles cross-drive paths,
case-insensitivity, and prefix traps natively without try/except.

Amp-Thread-ID: https://ampcode.com/threads/T-019c9224-d83c-7797-8c02-e1e1ae2ee452
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
05ed9e774a fix: include all model categories in scanning, not just those under models_dir
get_comfy_models_folders() previously filtered by startswith(models_root),
excluding extra model paths outside the main models directory. Now includes
every category with non-empty paths from folder_names_and_paths.

Amp-Thread-ID: https://ampcode.com/threads/T-019c9224-d83c-7797-8c02-e1e1ae2ee452
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
0fff4c980f fix: follow symlinks in list_files_recursively with cycle detection
list_files_recursively now uses followlinks=True so symlinked
directories under input/ and output/ roots are traversed, matching
the existing behavior of folder_paths.recursive_search for models.

Tracks (st_dev, st_ino) pairs of visited directories to detect and
break circular symlink loops safely.

Amp-Thread-ID: https://ampcode.com/threads/T-019c9220-21b8-7678-b428-9215ff1bb011
Co-authored-by: Amp <amp@ampcode.com>
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
b98e727582 feat(assets): async two-phase scanner and background seeder
- Rewrite scanner.py with two-phase scanning architecture (fast scan + enrich)
- Add AssetSeeder for non-blocking background startup scanning
- Implement pause/resume/stop/restart controls and disable/enable for --disable-assets-autoscan
- Add non-destructive asset pruning with is_missing flag
- Wire seeder into main.py and server.py lifecycle
- Skip hidden files/directories, populate mime_type, optional blake3 hashing
- Add comprehensive seeder tests

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019c9209-37af-757a-b6e4-af59b4267362
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
315aa8c3bf refactor(assets): API routes call services directly, extract upload handling
- Refactor routes.py to call service functions directly (no manager layer)
- Extract multipart upload parsing into upload.py
- Update API schemas
- Fix path traversal validation to return 400 instead of 500
- Rename test_tags.py to test_tags_api.py
- Update existing API-level tests

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019c9209-37af-757a-b6e4-af59b4267362
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
d621657143 refactor(assets): extract services layer from manager and helpers
- Create services/ package: asset_management, bulk_ingest, file_utils, hashing, ingest, metadata_extract, path_utils, schemas, tagging
- Move business logic out of helpers.py into service modules
- Remove manager.py and hashing.py (absorbed into services)
- Add blake3 to requirements.txt
- Add comprehensive service-layer tests

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019c9209-37af-757a-b6e4-af59b4267362
2026-03-07 15:13:17 -08:00
Luke Mino-Altherr
d280ae140f refactor(assets): database layer — split queries into modules and merge migrations
- Split monolithic queries.py into modular query modules (asset, asset_reference, common, tags)
- Absorb bulk_ops.py and tags.py into query modules
- Merge migrations 0002-0005 into single migration (0002_merge_to_asset_references)
- Update models.py (merge AssetInfo/AssetCacheState into AssetReference)
- Enable SQLite foreign key enforcement
- Add comprehensive query-layer tests

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019c917d-82b5-7448-a04f-9cd59c69d0a2
2026-03-07 15:13:17 -08:00
22 changed files with 51 additions and 643 deletions

View File

@@ -3,12 +3,8 @@ import os
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import IO, Any, Callable, Iterator from typing import IO, Any, Callable, Iterator
import logging
try: from blake3 import blake3
from blake3 import blake3
except ModuleNotFoundError:
logging.warning("WARNING: blake3 package not installed")
DEFAULT_CHUNK = 8 * 1024 * 1024 DEFAULT_CHUNK = 8 * 1024 * 1024

View File

@@ -223,19 +223,12 @@ class DoubleStreamBlock(nn.Module):
del txt_k, img_k del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2) v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v del txt_v, img_v
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# run actual attention # run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v
if "attn1_output_patch" in transformer_patches: if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"] patch = transformer_patches["attn1_output_patch"]
for p in patch: for p in patch:
attn = p(attn, extra_options) attn = p(attn, extra_options)
@@ -328,12 +321,6 @@ class SingleStreamBlock(nn.Module):
del qkv del qkv
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# compute attention # compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v

View File

@@ -31,8 +31,6 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def _apply_rope1(x: Tensor, freqs_cis: Tensor): def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
x_out = freqs_cis[..., 0] * x_[..., 0] x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

View File

@@ -170,7 +170,7 @@ class Flux(nn.Module):
if "post_input" in patches: if "post_input" in patches:
for p in patches["post_input"]: for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
img = out["img"] img = out["img"]
txt = out["txt"] txt = out["txt"]
img_ids = out["img_ids"] img_ids = out["img_ids"]

View File

@@ -372,8 +372,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2 del s2
break break
except Exception as e: except model_management.OOM_EXCEPTION as e:
model_management.raise_non_oom(e)
if first_op_done == False: if first_op_done == False:
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
if cleared_cache == False: if cleared_cache == False:

View File

@@ -258,8 +258,7 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2) r1[:, :, i:end] = torch.bmm(v, s2)
del s2 del s2
break break
except Exception as e: except model_management.OOM_EXCEPTION as e:
model_management.raise_non_oom(e)
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
steps *= 2 steps *= 2
if steps > 128: if steps > 128:
@@ -315,8 +314,7 @@ def pytorch_attention(q, k, v):
try: try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape) out = out.transpose(2, 3).reshape(orig_shape)
except Exception as e: except model_management.OOM_EXCEPTION:
model_management.raise_non_oom(e)
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True oom_fallback = True
if oom_fallback: if oom_fallback:

View File

@@ -169,8 +169,7 @@ def _get_attention_scores_no_kv_chunking(
try: try:
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
except Exception as e: except model_management.OOM_EXCEPTION:
model_management.raise_non_oom(e)
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)

View File

@@ -149,9 +149,6 @@ class Attention(nn.Module):
seq_img = hidden_states.shape[1] seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1]
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# Project and reshape to BHND format (batch, heads, seq, dim) # Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
@@ -170,22 +167,15 @@ class Attention(nn.Module):
joint_key = torch.cat([txt_key, img_key], dim=2) joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None: if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else: else:
attn_mask = None attn_mask = None
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options, attn_mask, transformer_options=transformer_options,
skip_reshape=True) skip_reshape=True)
@@ -454,7 +444,6 @@ class QwenImageTransformer2DModel(nn.Module):
timestep_zero_index = None timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
ref_num_tokens = []
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
@@ -485,16 +474,16 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1) hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero: if timestep_zero:
if index > 0: if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0) timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds timestep_zero_index = num_embeds
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -506,18 +495,6 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):

View File

@@ -99,9 +99,6 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk: for k in sdk:
if k.endswith(".weight"): if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False clip_l_present = False

View File

@@ -1,5 +1,4 @@
import json import json
import comfy.memory_management
import comfy.supported_models import comfy.supported_models
import comfy.supported_models_base import comfy.supported_models_base
import comfy.utils import comfy.utils
@@ -1119,13 +1118,8 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
new[:old_weight.shape[0]] = old_weight new[:old_weight.shape[0]] = old_weight
old_weight = new old_weight = new
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
old_weight = old_weight.clone()
w = old_weight.narrow(offset[0], offset[1], offset[2]) w = old_weight.narrow(offset[0], offset[1], offset[2])
else: else:
if comfy.memory_management.aimdo_enabled:
weight = weight.clone()
old_weight = weight old_weight = weight
w = weight w = weight
w[:] = fun(weight) w[:] = fun(weight)

View File

@@ -270,23 +270,6 @@ try:
except: except:
OOM_EXCEPTION = Exception OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
def raise_non_oom(e):
if not is_oom(e):
raise e
XFORMERS_VERSION = "" XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True XFORMERS_ENABLED_VAE = True
if args.disable_xformers: if args.disable_xformers:
@@ -1280,7 +1263,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b _ = a + b
synchronize() synchronize()
except RuntimeError: except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return #Dump it! We already know about it from the synchronous return
pass pass

View File

@@ -599,27 +599,6 @@ class ModelPatcher:
return models return models
def model_patches_call_function(self, function_name="cleanup", arguments={}):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], function_name):
getattr(patch_list[i], function_name)(**arguments)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], function_name):
getattr(patch_list[k], function_name)(**arguments)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, function_name):
getattr(wrap_func, function_name)(**arguments)
def model_dtype(self): def model_dtype(self):
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):
return self.model.get_dtype() return self.model.get_dtype()
@@ -1083,7 +1062,6 @@ class ModelPatcher:
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self): def cleanup(self):
self.model_patches_call_function(function_name="cleanup")
self.clean_hooks() self.clean_hooks()
if hasattr(self.model, "current_patcher"): if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None self.model.current_patcher = None

View File

@@ -954,8 +954,7 @@ class VAE:
if pixel_samples is None: if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out pixel_samples[x:x+batch_number] = out
except Exception as e: except model_management.OOM_EXCEPTION:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the #NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block. #exception and the exception itself refs them all until we get out of this except block.
@@ -1030,8 +1029,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out samples[x:x + batch_number] = out
except Exception as e: except model_management.OOM_EXCEPTION:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the #NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block. #exception and the exception itself refs them all until we get out of this except block.

View File

@@ -1,68 +0,0 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@@ -1,395 +0,0 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 524.48
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=5,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@@ -67,7 +67,6 @@ class _RequestConfig:
progress_origin_ts: float | None = None progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass @dataclass
@@ -203,13 +202,11 @@ async def sync_op_raw(
monitor_progress: bool = True, monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16, max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None, is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes: ) -> dict[str, Any] | bytes:
""" """
Make a single network request. Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON). - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes. - If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
""" """
if isinstance(data, BaseModel): if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True) data = data.model_dump(exclude_none=True)
@@ -235,7 +232,6 @@ async def sync_op_raw(
price_extractor=price_extractor, price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit, max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited, is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
) )
return await _request_base(cfg, expect_binary=as_binary) return await _request_base(cfg, expect_binary=as_binary)
@@ -773,12 +769,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
) )
bytes_payload = bytes(buff) bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time) final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response( request_logger.log_request_response(
@@ -786,7 +776,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method, request_method=method,
request_url=url, request_url=url,
response_status_code=resp.status, response_status_code=resp.status,
response_headers=resp_headers, response_headers=dict(resp.headers),
response_content=bytes_payload, response_content=bytes_payload,
) )
return bytes_payload return bytes_payload

View File

@@ -1,32 +1,32 @@
from comfy import model_management from comfy import model_management
from comfy_api.latest import ComfyExtension, IO
from typing_extensions import override
import math import math
class LTXVLatentUpsampler:
class LTXVLatentUpsampler(IO.ComfyNode):
""" """
Upsamples a video latent by a factor of 2. Upsamples a video latent by a factor of 2.
""" """
@classmethod @classmethod
def define_schema(cls): def INPUT_TYPES(s):
return IO.Schema( return {
node_id="LTXVLatentUpsampler", "required": {
category="latent/video", "samples": ("LATENT",),
is_experimental=True, "upscale_model": ("LATENT_UPSCALE_MODEL",),
inputs=[ "vae": ("VAE",),
IO.Latent.Input("samples"), }
IO.LatentUpscaleModel.Input("upscale_model"), }
IO.Vae.Input("vae"),
],
outputs=[
IO.Latent.Output(),
],
)
@classmethod RETURN_TYPES = ("LATENT",)
def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput: FUNCTION = "upsample_latent"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def upsample_latent(
self,
samples: dict,
upscale_model,
vae,
) -> tuple:
""" """
Upsample the input latent using the provided model. Upsample the input latent using the provided model.
@@ -34,6 +34,7 @@ class LTXVLatentUpsampler(IO.ComfyNode):
samples (dict): Input latent samples samples (dict): Input latent samples
upscale_model (LatentUpsampler): Loaded upscale model upscale_model (LatentUpsampler): Loaded upscale model
vae: VAE model for normalization vae: VAE model for normalization
auto_tiling (bool): Whether to automatically tile the input for processing
Returns: Returns:
tuple: Tuple containing the upsampled latent tuple: Tuple containing the upsampled latent
@@ -66,16 +67,9 @@ class LTXVLatentUpsampler(IO.ComfyNode):
return_dict = samples.copy() return_dict = samples.copy()
return_dict["samples"] = upsampled_latents return_dict["samples"] = upsampled_latents
return_dict.pop("noise_mask", None) return_dict.pop("noise_mask", None)
return IO.NodeOutput(return_dict) return (return_dict,)
upsample_latent = execute # TODO: remove
class LTXVLatentUpsamplerExtension(ComfyExtension): NODE_CLASS_MAPPINGS = {
@override "LTXVLatentUpsampler": LTXVLatentUpsampler,
async def get_node_list(self) -> list[type[IO.ComfyNode]]: }
return [LTXVLatentUpsampler]
async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension:
return LTXVLatentUpsamplerExtension()

View File

@@ -86,8 +86,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False oom = False
except Exception as e: except model_management.OOM_EXCEPTION as e:
model_management.raise_non_oom(e)
tile //= 2 tile //= 2
if tile < 128: if tile < 128:
raise e raise e

View File

@@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
tips = "" tips = ""
if comfy.model_management.is_oom(ex): if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.") logging.error("Got an OOM, unloading all loaded models.")

24
main.py
View File

@@ -3,16 +3,14 @@ comfy.options.enable_args_parsing()
import os import os
import importlib.util import importlib.util
import shutil
import importlib.metadata
import folder_paths import folder_paths
import time import time
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger from app.logger import setup_logger
from app.assets.seeder import asset_seeder
import itertools import itertools
import utils.extra_config import utils.extra_config
from utils.mime_types import init_mime_types from utils.mime_types import init_mime_types
import faulthandler
import logging import logging
import sys import sys
from comfy_execution.progress import get_progress_state from comfy_execution.progress import get_progress_state
@@ -27,8 +25,6 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
faulthandler.enable(file=sys.stderr, all_threads=False)
import comfy_aimdo.control import comfy_aimdo.control
if enables_dynamic_vram(): if enables_dynamic_vram():
@@ -68,15 +64,8 @@ if __name__ == "__main__":
def handle_comfyui_manager_unavailable(): def handle_comfyui_manager_unavailable():
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") if not args.windows_standalone_build:
uv_available = shutil.which("uv") is not None logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
if uv_available:
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
msg += "\n"
logging.warning(msg)
args.enable_manager = False args.enable_manager = False
@@ -184,6 +173,7 @@ execute_prestartup_script()
# Main code # Main code
import asyncio import asyncio
import shutil
import threading import threading
import gc import gc
@@ -192,7 +182,6 @@ if 'torch' in sys.modules:
import comfy.utils import comfy.utils
from app.assets.seeder import asset_seeder
import execution import execution
import server import server
@@ -462,11 +451,6 @@ if __name__ == "__main__":
# Running directly, just start ComfyUI. # Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version)) logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
for package in ("comfy-aimdo", "comfy-kitchen"):
try:
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
except:
pass
if sys.version_info.major == 3 and sys.version_info.minor < 10: if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")

View File

@@ -1 +1 @@
comfyui_manager==4.1b2 comfyui_manager==4.1b1

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19 comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.18 comfyui-workflow-templates==0.9.11
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
@@ -23,7 +23,7 @@ SQLAlchemy
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.10 comfy-aimdo>=0.2.7
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3