Compare commits

..

1 Commits

Author SHA1 Message Date
bymyself
cffeebfa01 fix: prevent --cpu flag from allocating GPU memory
Two root causes fixed:

1. soft_empty_cache() and synchronize() in model_management.py lacked a
   cpu_state == CPUState.CPU guard. They fell through to torch.cuda calls
   that initialize a CUDA context (150-500MB VRAM) even in CPU-only mode.

2. comfy_kitchen is imported unconditionally at startup via quant_ops.py.
   The import chain triggers torch.cuda.is_available() -> cuInit, which
   initializes the CUDA driver. Now gated behind args.cpu check.

Also adds missing QuantizedLayout and register_layout_op fallback stubs
that were absent from the original ImportError handler.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbd03-433e-7601-93ff-3887227496b4
2026-03-05 12:32:46 -08:00
175 changed files with 3665 additions and 515722 deletions

View File

@@ -1,103 +0,0 @@
#!/usr/bin/env bash
# Checks pull request commits for AI agent Co-authored-by trailers.
# Exits non-zero when any are found and prints fix instructions.
set -euo pipefail
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
# Known AI coding-agent trailer patterns (case-insensitive).
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
AGENT_PATTERNS=(
# Anthropic — Claude Code / Amp
'noreply@anthropic\.com'
# Cursor
'cursoragent@cursor\.com'
# GitHub Copilot
'copilot-swe-agent\[bot\]'
'copilot@github\.com'
# OpenAI Codex
'noreply@openai\.com'
'codex@openai\.com'
# Aider
'aider@aider\.chat'
# Google — Gemini / Jules
'gemini@google\.com'
'jules@google\.com'
# Windsurf / Codeium
'@codeium\.com'
# Devin
'devin-ai-integration\[bot\]'
'devin@cognition\.ai'
'devin@cognition-labs\.com'
# Amazon Q Developer
'amazon-q-developer'
'@amazon\.com.*[Qq].[Dd]eveloper'
# Cline
'cline-bot'
'cline@cline\.ai'
# Continue
'continue-agent'
'continue@continue\.dev'
# Sourcegraph
'noreply@sourcegraph\.com'
# Generic catch-alls for common agent name patterns
'Co-authored-by:.*\b[Cc]laude\b'
'Co-authored-by:.*\b[Cc]opilot\b'
'Co-authored-by:.*\b[Cc]ursor\b'
'Co-authored-by:.*\b[Cc]odex\b'
'Co-authored-by:.*\b[Gg]emini\b'
'Co-authored-by:.*\b[Aa]ider\b'
'Co-authored-by:.*\b[Dd]evin\b'
'Co-authored-by:.*\b[Ww]indsurf\b'
'Co-authored-by:.*\b[Cc]line\b'
'Co-authored-by:.*\b[Aa]mazon Q\b'
'Co-authored-by:.*\b[Jj]ules\b'
'Co-authored-by:.*\bOpenCode\b'
)
# Build a single alternation regex from all patterns.
regex=""
for pattern in "${AGENT_PATTERNS[@]}"; do
if [[ -n "$regex" ]]; then
regex="${regex}|${pattern}"
else
regex="$pattern"
fi
done
# Collect Co-authored-by lines from every commit in the PR range.
violations=""
while IFS= read -r sha; do
message="$(git log -1 --format='%B' "$sha")"
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
if [[ -z "$matched_lines" ]]; then
continue
fi
while IFS= read -r line; do
if echo "$line" | grep -iqE "$regex"; then
short="$(git log -1 --format='%h' "$sha")"
violations="${violations} ${short}: ${line}"$'\n'
fi
done <<< "$matched_lines"
done < <(git rev-list "${base_sha}..${head_sha}")
if [[ -n "$violations" ]]; then
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
echo ""
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
echo ""
echo "$violations"
echo "These trailers should be removed before merging."
echo ""
echo "To fix, rewrite the commit messages with:"
echo " git rebase -i ${base_sha}"
echo ""
echo "and remove the Co-authored-by lines, then force-push your branch."
echo ""
echo "If you believe this is a false positive, please open an issue."
exit 1
fi
echo "No AI agent Co-authored-by trailers found."

View File

@@ -1,19 +0,0 @@
name: Check AI Co-Authors
on:
pull_request:
branches: ['*']
jobs:
check-ai-co-authors:
name: Check for AI agent co-author trailers
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check commits for AI co-author trailers
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"

View File

@@ -38,8 +38,6 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started
### Local
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
@@ -51,13 +49,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
### Cloud
#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version for those who can't afford local hardware.
## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
@@ -232,7 +225,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:

View File

@@ -8,7 +8,7 @@ from alembic import context
config = context.config
from app.database.models import Base, NAMING_CONVENTION
from app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
@@ -51,10 +51,7 @@ def run_migrations_online() -> None:
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
naming_convention=NAMING_CONVENTION,
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():

View File

@@ -1,267 +0,0 @@
"""
Merge AssetInfo and AssetCacheState into unified asset_references table.
This migration drops old tables and creates the new unified schema.
All existing data is discarded.
Revision ID: 0002_merge_to_asset_references
Revises: 0001_assets
Create Date: 2025-02-11
"""
from alembic import op
import sqlalchemy as sa
revision = "0002_merge_to_asset_references"
down_revision = "0001_assets"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop old tables (order matters due to FK constraints)
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
# Truncate assets table (cascades handled by dropping dependent tables first)
op.execute("DELETE FROM assets")
# Create asset_references table
op.create_table(
"asset_references",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column(
"asset_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column(
"needs_verify",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
sa.Column(
"is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column(
"preview_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True),
sa.CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
sa.CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
)
op.create_index(
"uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
)
op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
op.create_index("ix_asset_references_name", "asset_references", ["name"])
op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
op.create_index(
"ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
)
op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
op.create_index(
"ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
)
op.create_index(
"ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
)
op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"])
# Create asset_reference_tags table
op.create_table(
"asset_reference_tags",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"tag_name",
sa.String(length=512),
sa.ForeignKey("tags.name", ondelete="RESTRICT"),
nullable=False,
),
sa.Column(
"origin", sa.String(length=32), nullable=False, server_default="manual"
),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint(
"asset_reference_id", "tag_name", name="pk_asset_reference_tags"
),
)
op.create_index(
"ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
)
op.create_index(
"ix_asset_reference_tags_asset_reference_id",
"asset_reference_tags",
["asset_reference_id"],
)
# Create asset_reference_meta table
op.create_table(
"asset_reference_meta",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint(
"asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
),
)
op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
op.create_index(
"ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
)
op.create_index(
"ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
)
op.create_index(
"ix_asset_reference_meta_key_val_bool",
"asset_reference_meta",
["key", "val_bool"],
)
def downgrade() -> None:
"""Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.
NOTE: Data is not recoverable. The upgrade discards all rows from the old
tables and truncates assets. After downgrade the old schema will be empty.
A filesystem rescan will repopulate data once the older code is running.
"""
# Drop new tables (order matters due to FK constraints)
op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
op.drop_table("asset_reference_meta")
op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
op.drop_table("asset_reference_tags")
op.drop_index("ix_asset_references_deleted_at", table_name="asset_references")
op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
op.drop_index("ix_asset_references_created_at", table_name="asset_references")
op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
op.drop_index("ix_asset_references_name", table_name="asset_references")
op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
op.drop_index("uq_asset_references_file_path", table_name="asset_references")
op.drop_table("asset_references")
# Truncate assets (upgrade deleted all rows; downgrade starts fresh too)
op.execute("DELETE FROM assets")
# Recreate old tables from 0001_assets schema
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

View File

@@ -1,98 +0,0 @@
"""
Add system_metadata and job_id columns to asset_references.
Change preview_id FK from assets.id to asset_references.id.
Revision ID: 0003_add_metadata_job_id
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
from app.database.models import NAMING_CONVENTION
revision = "0003_add_metadata_job_id"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("job_id", sa.String(length=36), nullable=True)
)
# Change preview_id FK from assets.id to asset_references.id (self-ref).
# Existing values are asset-content IDs that won't match reference IDs,
# so null them out first.
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_constraint(
"fk_asset_references_preview_id_assets", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_asset_references",
"asset_references",
["preview_id"],
["id"],
ondelete="SET NULL",
)
batch_op.create_index(
"ix_asset_references_preview_id", ["preview_id"]
)
# Purge any all-null meta rows before adding the constraint
op.execute(
"DELETE FROM asset_reference_meta"
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
)
with op.batch_alter_table("asset_reference_meta") as batch_op:
batch_op.create_check_constraint(
"ck_asset_reference_meta_has_value",
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
)
def downgrade() -> None:
# SQLite doesn't reflect CHECK constraints, so we must declare it
# explicitly via table_args for the batch recreate to find it.
# Use the fully-rendered constraint name to avoid the naming convention
# doubling the prefix.
with op.batch_alter_table(
"asset_reference_meta",
table_args=[
sa.CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="ck_asset_reference_meta_has_value",
),
],
) as batch_op:
batch_op.drop_constraint(
"ck_asset_reference_meta_has_value", type_="check"
)
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_index("ix_asset_references_preview_id")
batch_op.drop_constraint(
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_assets",
"assets",
["preview_id"],
["id"],
ondelete="SET NULL",
)
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("job_id")
batch_op.drop_column("system_metadata")

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,6 @@
import json
from dataclasses import dataclass
from typing import Any, Literal
from app.assets.helpers import validate_blake3_hash
from pydantic import (
BaseModel,
ConfigDict,
@@ -12,43 +10,6 @@ from pydantic import (
model_validator,
)
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
self.status = status
self.code = code
self.message = message
class AssetValidationError(Exception):
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
def __init__(self, code: str, message: str):
super().__init__(message)
self.code = code
self.message = message
@dataclass
class ParsedUpload:
"""Result of parsing a multipart upload request."""
file_present: bool
file_written: int
file_client_name: str | None
tmp_path: str | None
tags_raw: list[str]
provided_name: str | None
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@@ -60,9 +21,7 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
)
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@@ -100,17 +59,11 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel):
name: str | None = None
user_metadata: dict[str, Any] | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@model_validator(mode="after")
def _validate_at_least_one_field(self):
if all(
v is None
for v in (self.name, self.user_metadata, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, preview_id."
)
def _at_least_one(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@@ -118,20 +71,26 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str
name: str | None = None
name: str
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
return validate_blake3_hash(v or "")
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before")
@classmethod
def _normalize_tags_field(cls, v):
def _tags_norm(cls, v):
if v is None:
return []
if isinstance(v, list):
@@ -148,44 +107,6 @@ class CreateFromHashBody(BaseModel):
return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@@ -233,36 +154,38 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
- mime_type: optional MIME type override
- preview_id: optional asset_reference ID for preview
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
Files are stored using the content hash as filename stem.
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(default_factory=list)
tags: list[str] = Field(..., min_length=1)
name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None) # references an asset_reference id
@field_validator("hash", mode="before")
@classmethod
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip()
s = str(v).strip().lower()
if not s:
return None
return validate_blake3_hash(s)
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before")
@classmethod
@@ -331,13 +254,11 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after")
def _validate_order(self):
if not self.tags:
raise ValueError("at least one tag is required for uploads")
raise ValueError("tags must be provided and non-empty")
root = self.tags[0]
if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError(
"models uploads require a category tag as the second tag"
)
raise ValueError("models uploads require a category tag as the second tag")
return self

View File

@@ -4,10 +4,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class Asset(BaseModel):
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
``id`` here is the AssetReference id, not the content-addressed Asset id."""
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str | None = None
@@ -15,33 +12,61 @@ class Asset(BaseModel):
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
job_id: str | None = None
prompt_id: str | None = None # deprecated: use job_id
created_at: datetime
updated_at: datetime
created_at: datetime | None = None
updated_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel):
assets: list[Asset]
assets: list[AssetSummary]
total: int
has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel):
name: str
count: int
@@ -66,7 +91,3 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

View File

@@ -1,185 +0,0 @@
import logging
import os
import uuid
from typing import Callable
from aiohttp import web
import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
from app.assets.helpers import validate_blake3_hash
def normalize_and_validate_hash(s: str) -> str:
"""Validate and normalize a hash string.
Returns canonical 'blake3:<hex>' or raises UploadError.
"""
try:
return validate_blake3_hash(s)
except ValueError:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
async def parse_multipart_upload(
request: web.Request,
check_hash_exists: Callable[[str], bool],
) -> ParsedUpload:
"""
Parse a multipart/form-data upload request.
Args:
request: The aiohttp request
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
Returns:
ParsedUpload with parsed fields and temp file path
Raises:
UploadError: On validation or I/O errors
"""
if not (request.content_type or "").lower().startswith("multipart/"):
raise UploadError(
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
)
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
raise UploadError(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
if s:
provided_hash = normalize_and_validate_hash(s)
try:
provided_hash_exists = check_hash_exists(provided_hash)
except Exception as e:
logging.exception(
"check_hash_exists failed for hash=%s: %s", provided_hash, e
)
raise UploadError(
500,
"HASH_CHECK_FAILED",
"Backend error while checking asset hash.",
)
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# Hash exists - drain file but don't write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
)
continue
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
delete_temp_file_if_exists(tmp_path)
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
)
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
elif fname == "id":
raise UploadError(
400,
"UNSUPPORTED_FIELD",
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
)
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
)
if (
file_present
and file_written == 0
and not (provided_hash and provided_hash_exists)
):
delete_temp_file_if_exists(tmp_path)
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
return ParsedUpload(
file_present=file_present,
file_written=file_written,
file_client_name=file_client_name,
tmp_path=tmp_path,
tags_raw=tags_raw,
provided_name=provided_name,
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
)
def delete_temp_file_if_exists(tmp_path: str | None) -> None:
"""Safely remove a temp file and its parent directory if empty."""
if tmp_path:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except OSError as e:
logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
try:
parent = os.path.dirname(tmp_path)
if parent and os.path.isdir(parent):
os.rmdir(parent) # only succeeds if empty
except OSError:
pass

View File

@@ -0,0 +1,204 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
@@ -16,36 +16,47 @@ from sqlalchemy import (
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import get_utc_now
from app.database.models import Base
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
DateTime(timezone=False), nullable=False, default=utcnow
)
references: Mapped[list[AssetReference]] = relationship(
"AssetReference",
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
foreign_keys=lambda: [AssetReference.asset_id],
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
# preview_id on AssetReference is a self-referential FK to asset_references.id
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
@@ -53,126 +64,108 @@ class Asset(Base):
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetReference(Base):
"""Unified model combining file cache state and user-facing metadata.
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
Each row represents either:
- A filesystem reference (file_path is set) with cache state
- An API-created reference (file_path is NULL) without cache state
"""
__tablename__ = "asset_references"
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
# Cache state fields (from former AssetCacheState)
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
# Info fields (from former AssetInfo)
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
deleted_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=False), nullable=True, default=None
)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="references",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_ref: Mapped[AssetReference | None] = relationship(
"AssetReference",
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
remote_side=lambda: [AssetReference.id],
)
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
back_populates="asset_reference",
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="asset_reference",
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_references",
overlaps="tags,asset_infos",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_reference_tags",
back_populates="asset_references",
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_reference_links,asset_references,tag",
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
Index("uq_asset_references_file_path", "file_path", unique=True),
Index("ix_asset_references_asset_id", "asset_id"),
Index("ix_asset_references_owner_id", "owner_id"),
Index("ix_asset_references_name", "name"),
Index("ix_asset_references_is_missing", "is_missing"),
Index("ix_asset_references_enrichment_level", "enrichment_level"),
Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"),
Index("ix_asset_references_preview_id", "preview_id"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
path_part = f" path={self.file_path!r}" if self.file_path else ""
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetReferenceMeta(Base):
__tablename__ = "asset_reference_meta"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
@@ -182,44 +175,36 @@ class AssetReferenceMeta(Base):
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_reference: Mapped[AssetReference] = relationship(
back_populates="metadata_entries"
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_reference_meta_key", "key"),
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="has_value",
),
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetReferenceTag(Base):
__tablename__ = "asset_reference_tags"
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_reference_tags_tag_name", "tag_name"),
Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
@@ -229,18 +214,20 @@ class Tag(Base):
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="tag",
overlaps="asset_references,tags",
overlaps="asset_infos,tags",
)
asset_references: Mapped[list[AssetReference]] = relationship(
secondary="asset_reference_tags",
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_reference_links,tag_links,tags,asset_reference",
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@@ -0,0 +1,976 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@@ -1,137 +0,0 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
create_stub_asset,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
update_asset_hash_and_mime,
upsert_asset,
)
from app.assets.database.queries.asset_reference import (
CacheStateRow,
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
count_active_siblings,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
delete_assets_by_ids,
delete_orphaned_seed_asset,
delete_reference_by_id,
delete_references_by_ids,
fetch_reference_and_asset,
fetch_reference_asset_and_tags,
get_or_create_reference,
get_reference_by_file_path,
get_reference_by_id,
get_reference_with_owner_check,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_references_for_prefixes,
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
insert_reference,
list_all_file_paths_by_asset_id,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
rebuild_metadata_projection,
reference_exists,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
set_reference_system_metadata,
soft_delete_reference_by_id,
update_reference_access_time,
update_reference_name,
update_is_missing_by_asset_id,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
)
from app.assets.database.queries.tags import (
AddTagsResult,
RemoveTagsResult,
SetTagsResult,
add_missing_tag_for_asset_id,
add_tags_to_reference,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
set_reference_tags,
validate_tags_exist,
)
__all__ = [
"AddTagsResult",
"CacheStateRow",
"RemoveTagsResult",
"SetTagsResult",
"UnenrichedReferenceRow",
"add_missing_tag_for_asset_id",
"add_tags_to_reference",
"asset_exists_by_hash",
"bulk_insert_assets",
"bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_enrichment_level",
"count_active_siblings",
"create_stub_asset",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"convert_metadata_to_rows",
"delete_assets_by_ids",
"delete_orphaned_seed_asset",
"delete_reference_by_id",
"delete_references_by_ids",
"ensure_tags_exist",
"fetch_reference_and_asset",
"fetch_reference_asset_and_tags",
"get_asset_by_hash",
"get_existing_asset_ids",
"get_or_create_reference",
"get_reference_by_file_path",
"get_reference_by_id",
"get_reference_with_owner_check",
"get_reference_ids_by_ids",
"get_reference_tags",
"get_references_by_paths_and_asset_ids",
"get_references_for_prefixes",
"get_unenriched_references",
"get_unreferenced_unhashed_asset_ids",
"insert_reference",
"list_all_file_paths_by_asset_id",
"list_references_by_asset_id",
"list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",
"rebuild_metadata_projection",
"reference_exists",
"reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id",
"remove_tags_from_reference",
"restore_references_by_paths",
"set_reference_metadata",
"set_reference_preview",
"set_reference_system_metadata",
"soft_delete_reference_by_id",
"set_reference_tags",
"update_asset_hash_and_mime",
"update_is_missing_by_asset_id",
"update_reference_access_time",
"update_reference_name",
"update_reference_timestamps",
"update_reference_updated_at",
"upsert_asset",
"upsert_reference",
"validate_tags_exist",
]

View File

@@ -1,152 +0,0 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and not asset.mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def create_stub_asset(
session: Session,
size_bytes: int,
mime_type: str | None = None,
) -> Asset:
"""Create a new asset with no hash (stub for later enrichment)."""
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
session.add(asset)
session.flush()
return asset
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
if not rows:
return
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)
def get_existing_asset_ids(
session: Session,
asset_ids: list[str],
) -> set[str]:
"""Return the subset of asset_ids that exist in the database."""
if not asset_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
rows = session.execute(
select(Asset.id).where(Asset.id.in_(chunk))
).fetchall()
found.update(row[0] for row in rows)
return found
def update_asset_hash_and_mime(
session: Session,
asset_id: str,
asset_hash: str | None = None,
mime_type: str | None = None,
) -> bool:
"""Update asset hash and/or mime_type. Returns True if asset was found."""
asset = session.get(Asset, asset_id)
if not asset:
return False
if asset_hash is not None:
asset.hash = asset_hash
if mime_type is not None and not asset.mime_type:
asset.mime_type = mime_type
return True
def reassign_asset_references(
session: Session,
from_asset_id: str,
to_asset_id: str,
reference_id: str,
) -> None:
"""Reassign a reference from one asset to another.
Used when merging a stub asset into an existing asset with the same hash.
"""
ref = session.get(AssetReference, reference_id)
if ref and ref.asset_id == from_asset_id:
ref.asset_id = to_asset_id
session.flush()

File diff suppressed because it is too large Load Diff

View File

@@ -1,127 +0,0 @@
"""Shared utilities for database query modules."""
import os
from decimal import Decimal
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800
def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count."""
return max(1, MAX_BIND_PARAMS // max(1, cols))
def iter_chunks(seq, n: int):
"""Yield successive n-sized chunks from seq."""
for i in range(0, len(seq), n):
yield seq[i : i + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads.
Owner-less rows are visible to everyone.
"""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetReference.owner_id == ""
return AssetReference.owner_id.in_(["", owner_id])
def build_prefix_like_conditions(
prefixes: list[str],
) -> list[sa.sql.ColumnElement]:
"""Build LIKE conditions for matching file paths under directory prefixes."""
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
return sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@@ -1,418 +0,0 @@
from dataclasses import dataclass
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import (
Asset,
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id == reference_id)
.order_by(AssetReferenceTag.tag_name.asc())
)
).all()
]
def set_reference_tags(
session: Session,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsResult:
desired = normalize_tags(tags)
current = set(get_reference_tags(session, reference_id))
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
if to_remove:
session.execute(
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
def add_tags_to_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
reference_row: AssetReference | None = None,
) -> AddTagsResult:
if not reference_row:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return AddTagsResult(added=[], already_present=[], total_tags=total)
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = set(get_reference_tags(session, reference_id))
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_reference_tags(session, reference_id=reference_id))
return AddTagsResult(
added=sorted(((after - current) & want)),
already_present=sorted(want & current),
total_tags=sorted(after),
)
def remove_tags_from_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
) -> RemoveTagsResult:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
existing = set(get_reference_tags(session, reference_id))
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_reference_tags(session, reference_id=reference_id)
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
def add_missing_tag_for_asset_id(
session: Session,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetReference.id.label("asset_reference_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetReference.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetReferenceTag)
.from_select(
["asset_reference_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
)
def remove_missing_tag_for_asset_id(
session: Session,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id.in_(
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
),
AssetReferenceTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetReferenceTag.tag_name.label("tag_name"),
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
visible_tags_sq = (
select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
)
total_q = total_q.where(Tag.name.in_(visible_tags_sq))
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_reference_tags and asset_reference_meta.
Uses ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
"""
if tag_rows:
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
index_elements=[
AssetReferenceMeta.asset_reference_id,
AssetReferenceMeta.key,
AssetReferenceMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)

View File

@@ -0,0 +1,62 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

75
app/assets/hashing.py Normal file
View File

@@ -0,0 +1,75 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@@ -1,42 +1,226 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from typing import Sequence
from pathlib import Path
from typing import Literal, Any
import folder_paths
def select_best_live_path(states: Sequence) -> str:
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
alive = [
s
for s in states
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char in a LIKE prefix.
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
Returns (escaped_prefix, escape_char).
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def get_utc_now() -> datetime:
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
@@ -44,22 +228,85 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def validate_blake3_hash(s: str) -> str:
"""Validate and normalize a blake3 hash string.
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
Returns canonical 'blake3:<hex>' or raises ValueError.
def project_kv(key: str, value):
"""
s = s.strip().lower()
if not s or ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if (
algo != "blake3"
or len(digest) != 64
or any(c for c in digest if c not in "0123456789abcdef")
):
raise ValueError("hash must be 'blake3:<hex>'")
return f"{algo}:{digest}"
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

516
app/assets/manager.py Normal file
View File

@@ -0,0 +1,516 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@@ -1,582 +1,263 @@
import contextlib
import time
import logging
import os
from pathlib import Path
from typing import Callable, Literal, TypedDict
import sqlalchemy
import folder_paths
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
delete_orphaned_seed_asset,
delete_references_by_ids,
ensure_tags_exist,
get_asset_by_hash,
get_reference_by_id,
get_references_for_prefixes,
get_unenriched_references,
mark_references_missing_outside_prefixes,
reassign_asset_references,
remove_missing_tag_for_asset_id,
set_reference_system_metadata,
update_asset_hash_and_mime,
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
)
from app.assets.services.bulk_ingest import (
SeedAssetSpec,
batch_insert_seed_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
is_visible,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
from app.database.db import create_session
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
class _RefInfo(TypedDict):
ref_id: str
file_path: str
exists: bool
stat_unchanged: bool
needs_verify: bool
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
refs: list[_RefInfo]
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
RootType = Literal["models", "input", "output"]
def get_prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [p for root in all_roots for p in get_prefixes_for_root(root)]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
if not all(is_visible(part) for part in Path(rel_path).parts):
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
abs_p = Path(abs_path)
for b in bases:
if abs_p.is_relative_to(os.path.abspath(b)):
allowed = True
break
if allowed:
out.append(abs_path)
return out
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
def sync_references_with_filesystem(
session,
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Reconcile asset references with filesystem for a root.
- Toggle needs_verify per reference using mtime/size stat check
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
- For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on stat check in this root
- Optionally return surviving absolute paths
Args:
session: Database session
root: Root type to scan
collect_existing_paths: If True, return set of surviving file paths
update_missing_tags: If True, update 'missing' tags based on file status
Returns:
Set of surviving absolute paths if collect_existing_paths=True, else None
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = get_prefixes_for_root(root)
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
rows = get_references_for_prefixes(
session, prefixes, include_missing=update_missing_tags
)
by_asset: dict[str, _AssetAccumulator] = {}
for row in rows:
acc = by_asset.get(row.asset_id)
if acc is None:
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[row.asset_id] = acc
stat_unchanged = False
try:
exists = True
stat_unchanged = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except PermissionError:
exists = True
logging.debug("Permission denied accessing %s", row.file_path)
except OSError as e:
exists = False
logging.debug("OSError checking %s: %s", row.file_path, e)
acc["refs"].append(
{
"ref_id": row.reference_id,
"file_path": row.file_path,
"exists": exists,
"stat_unchanged": stat_unchanged,
"needs_verify": row.needs_verify,
}
)
to_set_verify: list[str] = []
to_clear_verify: list[str] = []
stale_ref_ids: list[str] = []
to_mark_missing: list[str] = []
to_clear_missing: list[str] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
refs = acc["refs"]
any_unchanged = any(r["stat_unchanged"] for r in refs)
all_missing = all(not r["exists"] for r in refs)
for r in refs:
if not r["exists"]:
to_mark_missing.append(r["ref_id"])
continue
if r["stat_unchanged"]:
to_clear_missing.append(r["ref_id"])
if r["needs_verify"]:
to_clear_verify.append(r["ref_id"])
if not r["stat_unchanged"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"])
if a_hash is None:
if refs and all_missing:
delete_orphaned_seed_asset(session, aid)
else:
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["file_path"]))
continue
if any_unchanged:
for r in refs:
if not r["exists"]:
stale_ref_ids.append(r["ref_id"])
if update_missing_tags:
try:
remove_missing_tag_for_asset_id(session, asset_id=aid)
except Exception as e:
logging.warning(
"Failed to remove missing tag for asset %s: %s", aid, e
)
elif update_missing_tags:
try:
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
except Exception as e:
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["file_path"]))
delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids)
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
bulk_update_is_missing(session, to_mark_missing, value=True)
bulk_update_is_missing(session, to_clear_missing, value=False)
bulk_update_needs_verify(session, to_set_verify, value=True)
bulk_update_needs_verify(session, to_clear_verify, value=False)
return survivors if collect_existing_paths else None
def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's references with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_references_with_filesystem(
sess,
root,
collect_existing_paths=True,
update_missing_tags=True,
)
sess.commit()
return survivors or set()
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", root, e)
return set()
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
"""Mark references as missing when outside the given prefixes.
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
"""
try:
with create_session() as sess:
count = mark_references_missing_outside_prefixes(sess, prefixes)
sess.commit()
return count
except Exception as e:
logging.exception("marking missing assets failed: %s", e)
return 0
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths
def build_asset_specs(
paths: list[str],
existing_paths: set[str],
enable_metadata_extraction: bool = True,
compute_hashes: bool = False,
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
compute_hashes: If True, compute blake3 hashes (slow for large files)
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=True)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
metadata = None
if enable_metadata_extraction:
metadata = extract_file_metadata(
abs_p,
stat_result=stat_p,
relative_filename=rel_fname,
)
# Compute hash if requested
asset_hash: str | None = None
if compute_hashes:
try:
digest, _ = compute_blake3_hash(abs_p)
asset_hash = "blake3:" + digest
except Exception as e:
logging.warning("Failed to hash %s: %s", abs_p, e)
mime_type = metadata.content_type if metadata else None
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": rel_fname,
"metadata": metadata,
"hash": asset_hash,
"mime_type": mime_type,
"job_id": None,
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created refs."""
if not specs:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_refs
# Enrichment level constants
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
def get_unenriched_assets_for_roots(
roots: tuple[RootType, ...],
max_level: int = ENRICHMENT_STUB,
limit: int = 1000,
) -> list:
"""Get assets that need enrichment for the given roots.
Args:
roots: Tuple of root types to scan
max_level: Maximum enrichment level to include
limit: Maximum number of rows to return
Returns:
List of UnenrichedReferenceRow
"""
prefixes: list[str] = []
for root in roots:
prefixes.extend(get_prefixes_for_root(root))
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
return get_unenriched_references(
sess, prefixes, max_level=max_level, limit=limit
)
def enrich_asset(
session,
file_path: str,
reference_id: str,
asset_id: str,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> int:
"""Enrich a single asset with metadata and/or hash.
Args:
session: Database session (caller manages lifecycle)
file_path: Absolute path to the file
reference_id: ID of the reference to update
asset_id: ID of the asset to update (for mime_type and hash)
extract_metadata: If True, extract safetensors header and mime type
compute_hash: If True, compute blake3 hash
interrupt_check: Optional non-blocking callable that returns True if
the operation should be interrupted (e.g. paused or cancelled)
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
New enrichment level achieved
"""
new_level = ENRICHMENT_STUB
try:
stat_p = os.stat(file_path, follow_symlinks=True)
except OSError:
return new_level
initial_mtime_ns = get_mtime_ns(stat_p)
rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None
metadata = None
if extract_metadata:
metadata = extract_file_metadata(
file_path,
stat_result=stat_p,
relative_filename=rel_fname,
)
if metadata:
mime_type = metadata.content_type
new_level = ENRICHMENT_METADATA
full_hash: str | None = None
if compute_hash:
try:
mtime_before = get_mtime_ns(stat_p)
size_before = stat_p.st_size
# Restore checkpoint if available and file unchanged
checkpoint = None
if hash_checkpoints is not None:
checkpoint = hash_checkpoints.get(file_path)
if checkpoint is not None:
cur_stat = os.stat(file_path, follow_symlinks=True)
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
or checkpoint.file_size != cur_stat.st_size):
checkpoint = None
hash_checkpoints.pop(file_path, None)
else:
mtime_before = get_mtime_ns(cur_stat)
digest, new_checkpoint = compute_blake3_hash(
file_path,
interrupt_check=interrupt_check,
checkpoint=checkpoint,
)
if digest is None:
# Interrupted — save checkpoint for later resumption
if hash_checkpoints is not None and new_checkpoint is not None:
new_checkpoint.mtime_ns = mtime_before
new_checkpoint.file_size = size_before
hash_checkpoints[file_path] = new_checkpoint
return new_level
# Completed — clear any saved checkpoint
if hash_checkpoints is not None:
hash_checkpoints.pop(file_path, None)
stat_after = os.stat(file_path, follow_symlinks=True)
mtime_after = get_mtime_ns(stat_after)
if mtime_before != mtime_after:
logging.warning("File modified during hashing, discarding hash: %s", file_path)
else:
full_hash = f"blake3:{digest}"
metadata_ok = not extract_metadata or metadata is not None
if metadata_ok:
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
# Optimistic guard: if the reference's mtime_ns changed since we
# started (e.g. ingest_existing_file updated it), our results are
# stale — discard them to avoid overwriting fresh registration data.
ref = get_reference_by_id(session, reference_id)
if ref is None or ref.mtime_ns != initial_mtime_ns:
session.rollback()
logging.info(
"Ref %s mtime changed during enrichment, discarding stale result",
reference_id,
)
return ENRICHMENT_STUB
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:
existing = get_asset_by_hash(session, full_hash)
if existing and existing.id != asset_id:
reassign_asset_references(session, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(session, asset_id)
if mime_type:
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
else:
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
elif mime_type:
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
bulk_update_enrichment_level(session, [reference_id], new_level)
session.commit()
return new_level
def enrich_assets_batch(
rows: list,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> tuple[int, list[str]]:
"""Enrich a batch of assets.
Uses a single DB session for the entire batch, committing after each
individual asset to avoid long-held transactions while eliminating
per-asset session creation overhead.
Args:
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
extract_metadata: If True, extract metadata for each asset
compute_hash: If True, compute hash for each asset
interrupt_check: Optional non-blocking callable that returns True if
the operation should be interrupted (e.g. paused or cancelled)
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
Tuple of (enriched_count, failed_reference_ids)
"""
enriched = 0
failed_ids: list[str] = []
with create_session() as sess:
for row in rows:
if interrupt_check is not None and interrupt_check():
break
try:
new_level = enrich_asset(
sess,
file_path=row.file_path,
reference_id=row.reference_id,
asset_id=row.asset_id,
extract_metadata=extract_metadata,
compute_hash=compute_hash,
interrupt_check=interrupt_check,
hash_checkpoints=hash_checkpoints,
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
if new_level > row.enrichment_level:
enriched += 1
else:
failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
sess.rollback()
failed_ids.append(row.reference_id)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
return enriched, failed_ids
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None

View File

@@ -1,846 +0,0 @@
"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable
from app.assets.scanner import (
ENRICHMENT_METADATA,
ENRICHMENT_STUB,
RootType,
build_asset_specs,
collect_paths_for_roots,
enrich_assets_batch,
get_all_known_prefixes,
get_prefixes_for_root,
get_unenriched_assets_for_roots,
insert_asset_specs,
mark_missing_outside_prefixes_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
class ScanInProgressError(Exception):
"""Raised when an operation cannot proceed because a scan is running."""
class State(Enum):
"""Seeder state machine states."""
IDLE = "IDLE"
RUNNING = "RUNNING"
PAUSED = "PAUSED"
CANCELLING = "CANCELLING"
class ScanPhase(Enum):
"""Scan phase options."""
FAST = "fast" # Phase 1: filesystem only (stubs)
ENRICH = "enrich" # Phase 2: metadata + hash
FULL = "full" # Both phases sequentially
@dataclass
class Progress:
"""Progress information for a scan operation."""
scanned: int = 0
total: int = 0
created: int = 0
skipped: int = 0
@dataclass
class ScanStatus:
"""Current status of the asset seeder."""
state: State
progress: Progress | None
errors: list[str] = field(default_factory=list)
ProgressCallback = Callable[[Progress], None]
class _AssetSeeder:
"""Background asset scanning manager.
Spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete.
Use the module-level ``asset_seeder`` instance.
"""
def __init__(self) -> None:
# RLock is required because _run_scan() drains pending work while
# holding _lock and re-enters start() which also acquires _lock.
self._lock = threading.RLock()
self._state = State.IDLE
self._progress: Progress | None = None
self._last_progress: Progress | None = None
self._errors: list[str] = []
self._thread: threading.Thread | None = None
self._cancel_event = threading.Event()
self._run_gate = threading.Event()
self._run_gate.set() # Start unpaused (set = running, clear = paused)
self._roots: tuple[RootType, ...] = ()
self._phase: ScanPhase = ScanPhase.FULL
self._compute_hashes: bool = False
self._prune_first: bool = False
self._progress_callback: ProgressCallback | None = None
self._disabled: bool = False
self._pending_enrich: dict | None = None
def disable(self) -> None:
"""Disable the asset seeder, preventing any scans from starting."""
self._disabled = True
logging.info("Asset seeder disabled")
def is_disabled(self) -> bool:
"""Check if the asset seeder is disabled."""
return self._disabled
def start(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
phase: ScanPhase = ScanPhase.FULL,
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
compute_hashes: bool = False,
) -> bool:
"""Start a background scan for the given roots.
Args:
roots: Tuple of root types to scan (models, input, output)
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
progress_callback: Optional callback called with progress updates
prune_first: If True, prune orphaned assets before scanning
compute_hashes: If True, compute blake3 hashes (slow)
Returns:
True if scan was started, False if already running
"""
if self._disabled:
logging.debug("Asset seeder is disabled, skipping start")
return False
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
with self._lock:
if self._state != State.IDLE:
logging.info("Asset seeder already running, skipping start")
return False
self._state = State.RUNNING
self._progress = Progress()
self._errors = []
self._roots = roots
self._phase = phase
self._prune_first = prune_first
self._compute_hashes = compute_hashes
self._progress_callback = progress_callback
self._cancel_event.clear()
self._run_gate.set() # Ensure unpaused when starting
self._thread = threading.Thread(
target=self._run_scan,
name="_AssetSeeder",
daemon=True,
)
self._thread.start()
return True
def start_fast(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
) -> bool:
"""Start a fast scan (phase 1 only) - creates stub records.
Args:
roots: Tuple of root types to scan
progress_callback: Optional callback for progress updates
prune_first: If True, prune orphaned assets before scanning
Returns:
True if scan was started, False if already running
"""
return self.start(
roots=roots,
phase=ScanPhase.FAST,
progress_callback=progress_callback,
prune_first=prune_first,
compute_hashes=False,
)
def start_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
Args:
roots: Tuple of root types to scan
progress_callback: Optional callback for progress updates
compute_hashes: If True, compute blake3 hashes
Returns:
True if scan was started, False if already running
"""
return self.start(
roots=roots,
phase=ScanPhase.ENRICH,
progress_callback=progress_callback,
prune_first=False,
compute_hashes=compute_hashes,
)
def enqueue_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan now, or queue it for after the current scan.
If the seeder is idle, starts immediately. Otherwise, the enrich
request is stored and will run automatically when the current scan
finishes.
Args:
roots: Tuple of root types to scan
compute_hashes: If True, compute blake3 hashes
Returns:
True if started immediately, False if queued for later
"""
with self._lock:
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
return True
if self._pending_enrich is not None:
existing_roots = set(self._pending_enrich["roots"])
existing_roots.update(roots)
self._pending_enrich["roots"] = tuple(existing_roots)
self._pending_enrich["compute_hashes"] = (
self._pending_enrich["compute_hashes"] or compute_hashes
)
else:
self._pending_enrich = {
"roots": roots,
"compute_hashes": compute_hashes,
}
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
return False
def cancel(self) -> bool:
"""Request cancellation of the current scan.
Returns:
True if cancellation was requested, False if not running or paused
"""
with self._lock:
if self._state not in (State.RUNNING, State.PAUSED):
return False
logging.info("Asset seeder cancelling (was %s)", self._state.value)
self._state = State.CANCELLING
self._cancel_event.set()
self._run_gate.set() # Unblock if paused so thread can exit
return True
def stop(self) -> bool:
"""Stop the current scan (alias for cancel).
Returns:
True if stop was requested, False if not running
"""
return self.cancel()
def pause(self) -> bool:
"""Pause the current scan.
The scan will complete its current batch before pausing.
Returns:
True if pause was requested, False if not running
"""
with self._lock:
if self._state != State.RUNNING:
return False
logging.info("Asset seeder pausing")
self._state = State.PAUSED
self._run_gate.clear()
return True
def resume(self) -> bool:
"""Resume a paused scan.
This is a noop if the scan is not in the PAUSED state
Returns:
True if resumed, False if not paused
"""
with self._lock:
if self._state != State.PAUSED:
return False
logging.info("Asset seeder resuming")
self._state = State.RUNNING
self._run_gate.set()
self._emit_event("assets.seed.resumed", {})
return True
def restart(
self,
roots: tuple[RootType, ...] | None = None,
phase: ScanPhase | None = None,
progress_callback: ProgressCallback | None = None,
prune_first: bool | None = None,
compute_hashes: bool | None = None,
timeout: float = 5.0,
) -> bool:
"""Cancel any running scan and start a new one.
Args:
roots: Roots to scan (defaults to previous roots)
phase: Scan phase (defaults to previous phase)
progress_callback: Progress callback (defaults to previous)
prune_first: Prune before scan (defaults to previous)
compute_hashes: Compute hashes (defaults to previous)
timeout: Max seconds to wait for current scan to stop
Returns:
True if new scan was started, False if failed to stop previous
"""
logging.info("Asset seeder restart requested")
with self._lock:
prev_roots = self._roots
prev_phase = self._phase
prev_callback = self._progress_callback
prev_prune = self._prune_first
prev_hashes = self._compute_hashes
self.cancel()
if not self.wait(timeout=timeout):
return False
cb = progress_callback if progress_callback is not None else prev_callback
return self.start(
roots=roots if roots is not None else prev_roots,
phase=phase if phase is not None else prev_phase,
progress_callback=cb,
prune_first=prune_first if prune_first is not None else prev_prune,
compute_hashes=(
compute_hashes if compute_hashes is not None else prev_hashes
),
)
def wait(self, timeout: float | None = None) -> bool:
"""Wait for the current scan to complete.
Args:
timeout: Maximum seconds to wait, or None for no timeout
Returns:
True if scan completed, False if timeout expired or no scan running
"""
with self._lock:
thread = self._thread
if thread is None:
return True
thread.join(timeout=timeout)
return not thread.is_alive()
def get_status(self) -> ScanStatus:
"""Get the current status and progress of the seeder."""
with self._lock:
src = self._progress or self._last_progress
return ScanStatus(
state=self._state,
progress=Progress(
scanned=src.scanned,
total=src.total,
created=src.created,
skipped=src.skipped,
)
if src
else None,
errors=list(self._errors),
)
def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown: cancel any running scan and wait for thread.
Args:
timeout: Maximum seconds to wait for thread to exit
"""
self.cancel()
self.wait(timeout=timeout)
with self._lock:
self._thread = None
def mark_missing_outside_prefixes(self) -> int:
"""Mark references as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but references are flagged as missing.
They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans
from accidentally marking assets belonging to other roots.
Should be called explicitly when cleanup is desired, typically after
a full scan of all roots or during maintenance.
Returns:
Number of references marked as missing
Raises:
ScanInProgressError: If a scan is currently running
"""
with self._lock:
if self._state != State.IDLE:
raise ScanInProgressError(
"Cannot mark missing assets while scan is running"
)
self._state = State.RUNNING
try:
if not dependencies_available():
logging.warning(
"Database dependencies not available, skipping mark missing"
)
return 0
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d references as missing", marked)
return marked
finally:
with self._lock:
self._reset_to_idle()
def _reset_to_idle(self) -> None:
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _is_cancelled(self) -> bool:
"""Check if cancellation has been requested."""
return self._cancel_event.is_set()
def _is_paused_or_cancelled(self) -> bool:
"""Non-blocking check: True if paused or cancelled.
Use as interrupt_check for I/O-bound work (e.g. hashing) so that
file handles are released immediately on pause rather than held
open while blocked. The caller is responsible for blocking on
_check_pause_and_cancel() afterward.
"""
return not self._run_gate.is_set() or self._cancel_event.is_set()
def _check_pause_and_cancel(self) -> bool:
"""Block while paused, then check if cancelled.
Call this at checkpoint locations in scan loops. It will:
1. Block indefinitely while paused (until resume or cancel)
2. Return True if cancelled, False to continue
Returns:
True if scan should stop, False to continue
"""
if not self._run_gate.is_set():
self._emit_event("assets.seed.paused", {})
self._run_gate.wait() # Blocks if paused
return self._is_cancelled()
def _emit_event(self, event_type: str, data: dict) -> None:
"""Emit a WebSocket event if server is available."""
try:
from server import PromptServer
if hasattr(PromptServer, "instance") and PromptServer.instance:
PromptServer.instance.send_sync(event_type, data)
except Exception:
pass
def _update_progress(
self,
scanned: int | None = None,
total: int | None = None,
created: int | None = None,
skipped: int | None = None,
) -> None:
"""Update progress counters (thread-safe)."""
callback: ProgressCallback | None = None
progress: Progress | None = None
with self._lock:
if self._progress is None:
return
if scanned is not None:
self._progress.scanned = scanned
if total is not None:
self._progress.total = total
if created is not None:
self._progress.created = created
if skipped is not None:
self._progress.skipped = skipped
if self._progress_callback:
callback = self._progress_callback
progress = Progress(
scanned=self._progress.scanned,
total=self._progress.total,
created=self._progress.created,
skipped=self._progress.skipped,
)
if callback and progress:
try:
callback(progress)
except Exception:
pass
_MAX_ERRORS = 200
def _add_error(self, message: str) -> None:
"""Add an error message (thread-safe), capped at _MAX_ERRORS."""
with self._lock:
if len(self._errors) < self._MAX_ERRORS:
self._errors.append(message)
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
"""Log the directories that will be scanned."""
import folder_paths
for root in roots:
if root == "models":
logging.info(
"Asset scan [models] directory: %s",
os.path.abspath(folder_paths.models_dir),
)
else:
prefixes = get_prefixes_for_root(root)
if prefixes:
logging.info("Asset scan [%s] directories: %s", root, prefixes)
def _run_scan(self) -> None:
"""Main scan loop running in background thread."""
t_start = time.perf_counter()
roots = self._roots
phase = self._phase
cancelled = False
total_created = 0
total_enriched = 0
skipped_existing = 0
total_paths = 0
try:
if not dependencies_available():
self._add_error("Database dependencies not available")
self._emit_event(
"assets.seed.error",
{"message": "Database dependencies not available"},
)
return
if self._prune_first:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d refs as missing before scan", marked)
if self._check_pause_and_cancel():
logging.info("Asset scan cancelled after pruning phase")
cancelled = True
return
self._log_scan_config(roots)
# Phase 1: Fast scan (stub records)
if phase in (ScanPhase.FAST, ScanPhase.FULL):
created, skipped, paths = self._run_fast_phase(roots)
total_created, skipped_existing, total_paths = created, skipped, paths
if self._check_pause_and_cancel():
cancelled = True
return
self._emit_event(
"assets.seed.fast_complete",
{
"roots": list(roots),
"created": total_created,
"skipped": skipped_existing,
"total": total_paths,
},
)
# Phase 2: Enrichment scan (metadata + hashes)
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
if self._check_pause_and_cancel():
cancelled = True
return
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
if enrich_cancelled:
cancelled = True
return
self._emit_event(
"assets.seed.enrich_complete",
{
"roots": list(roots),
"enriched": total_enriched,
},
)
elapsed = time.perf_counter() - t_start
logging.info(
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
roots,
phase.value,
elapsed,
total_created,
total_enriched,
skipped_existing,
)
self._emit_event(
"assets.seed.completed",
{
"phase": phase.value,
"total": total_paths,
"created": total_created,
"enriched": total_enriched,
"skipped": skipped_existing,
"elapsed": round(elapsed, 3),
},
)
except Exception as e:
self._add_error(f"Scan failed: {e}")
logging.exception("Asset scan failed")
self._emit_event("assets.seed.error", {"message": str(e)})
finally:
if cancelled:
self._emit_event(
"assets.seed.cancelled",
{
"scanned": self._progress.scanned if self._progress else 0,
"total": total_paths,
"created": total_created,
},
)
with self._lock:
self._reset_to_idle()
pending = self._pending_enrich
if pending is not None:
self._pending_enrich = None
if not self.start_enrich(
roots=pending["roots"],
compute_hashes=pending["compute_hashes"],
):
logging.warning(
"Pending enrich scan could not start (roots=%s)",
pending["roots"],
)
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
"""Run phase 1: fast scan to create stub records.
Returns:
Tuple of (total_created, skipped_existing, total_paths)
"""
t_fast_start = time.perf_counter()
total_created = 0
skipped_existing = 0
existing_paths: set[str] = set()
t_sync = time.perf_counter()
for r in roots:
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
existing_paths.update(sync_root_safely(r))
logging.debug(
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
time.perf_counter() - t_sync,
len(existing_paths),
)
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
t_collect = time.perf_counter()
paths = collect_paths_for_roots(roots)
logging.debug(
"Fast scan: collect_paths took %.3fs (%d paths found)",
time.perf_counter() - t_collect,
len(paths),
)
total_paths = len(paths)
self._update_progress(total=total_paths)
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "total": total_paths, "phase": "fast"},
)
# Use stub specs (no metadata extraction, no hashing)
t_specs = time.perf_counter()
specs, tag_pool, skipped_existing = build_asset_specs(
paths,
existing_paths,
enable_metadata_extraction=False,
compute_hashes=False,
)
logging.debug(
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
time.perf_counter() - t_specs,
len(specs),
skipped_existing,
)
self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel():
return total_created, skipped_existing, total_paths
batch_size = 500
last_progress_time = time.perf_counter()
progress_interval = 1.0
for i in range(0, len(specs), batch_size):
if self._check_pause_and_cancel():
logging.info(
"Fast scan cancelled after %d/%d files (created=%d)",
i,
len(specs),
total_created,
)
return total_created, skipped_existing, total_paths
batch = specs[i : i + batch_size]
batch_tags = {t for spec in batch for t in spec["tags"]}
try:
created = insert_asset_specs(batch, batch_tags)
total_created += created
except Exception as e:
self._add_error(f"Batch insert failed at offset {i}: {e}")
logging.exception("Batch insert failed at offset %d", i)
scanned = i + len(batch)
now = time.perf_counter()
self._update_progress(scanned=scanned, created=total_created)
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"phase": "fast",
"scanned": scanned,
"total": len(specs),
"created": total_created,
},
)
last_progress_time = now
self._update_progress(scanned=len(specs), created=total_created)
logging.info(
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
time.perf_counter() - t_fast_start,
total_created,
skipped_existing,
total_paths,
)
return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
"""Run phase 2: enrich existing records with metadata and hashes.
Returns:
Tuple of (cancelled, total_enriched)
"""
total_enriched = 0
batch_size = 100
last_progress_time = time.perf_counter()
progress_interval = 1.0
# Get the target enrichment level based on compute_hashes
if not self._compute_hashes:
target_max_level = ENRICHMENT_STUB
else:
target_max_level = ENRICHMENT_METADATA
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "phase": "enrich"},
)
skip_ids: set[str] = set()
consecutive_empty = 0
max_consecutive_empty = 3
# Hash checkpoints survive across batches so interrupted hashes
# can be resumed without re-reading the entire file.
hash_checkpoints: dict[str, object] = {}
while True:
if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched)
return True, total_enriched
# Fetch next batch of unenriched assets
unenriched = get_unenriched_assets_for_roots(
roots,
max_level=target_max_level,
limit=batch_size,
)
# Filter out previously failed references
if skip_ids:
unenriched = [r for r in unenriched if r.reference_id not in skip_ids]
if not unenriched:
break
enriched, failed_ids = enrich_assets_batch(
unenriched,
extract_metadata=True,
compute_hash=self._compute_hashes,
interrupt_check=self._is_paused_or_cancelled,
hash_checkpoints=hash_checkpoints,
)
total_enriched += enriched
skip_ids.update(failed_ids)
if enriched == 0:
consecutive_empty += 1
if consecutive_empty >= max_consecutive_empty:
logging.warning(
"Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
consecutive_empty,
len(skip_ids),
)
break
else:
consecutive_empty = 0
now = time.perf_counter()
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"phase": "enrich",
"enriched": total_enriched,
},
)
last_progress_time = now
return False, total_enriched
asset_seeder = _AssetSeeder()

View File

@@ -1,91 +0,0 @@
from app.assets.services.asset_management import (
asset_exists,
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
)
from app.assets.services.bulk_ingest import (
BulkInsertResult,
batch_insert_seed_assets,
cleanup_unreferenced_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
get_size_and_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
ingest_existing_file,
register_output_files,
upload_from_temp_path,
)
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
)
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
ReferenceData,
RegisterAssetResult,
TagUsage,
UploadResult,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
remove_tags,
)
__all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetSummaryData",
"ReferenceData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",
"HashMismatchError",
"IngestResult",
"ListAssetsResult",
"RegisterAssetResult",
"RemoveTagsResult",
"TagUsage",
"UploadResult",
"UserMetadata",
"apply_tags",
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"ingest_existing_file",
"register_output_files",
"get_mtime_ns",
"get_size_and_mtime_ns",
"list_assets_page",
"list_files_recursively",
"list_tags",
"cleanup_unreferenced_assets",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",
"update_asset_metadata",
"upload_from_temp_path",
"verify_file_unchanged",
]

View File

@@ -1,367 +0,0 @@
import contextlib
import mimetypes
import os
from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
reference_exists_for_asset_id,
delete_reference_by_id,
fetch_reference_and_asset,
soft_delete_reference_by_id,
fetch_reference_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_reference_by_id,
get_reference_with_owner_check,
list_references_page,
list_all_file_paths_by_asset_id,
list_references_by_asset_id,
set_reference_metadata,
set_reference_preview,
set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time,
update_reference_name,
update_reference_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_relative_filename
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_reference_data,
)
from app.database.db import create_session
def get_asset_detail(
reference_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_reference_asset_and_tags(
session,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
return None
ref, asset, tags = result
return AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
reference_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
touched = False
if name is not None and name != ref.name:
update_reference_name(session, reference_id=reference_id, name=name)
touched = True
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = ref.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_reference_metadata(
session, reference_id=reference_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_reference_tags(
session,
reference_id=reference_id,
tags=tags,
origin=tag_origin,
)
touched = True
if mime_type is not None:
updated = update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
if updated:
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_id,
)
touched = True
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)
result = fetch_reference_asset_and_tags(
session,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
ref, asset, tag_list = result
detail = AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_list,
)
session.commit()
return detail
def delete_asset_reference(
reference_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
if not delete_content_if_orphan:
# Soft delete: mark the reference as deleted but keep everything
deleted = soft_delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
session.commit()
return deleted
ref_row = get_reference_by_id(session, reference_id=reference_id)
asset_id = ref_row.asset_id if ref_row else None
file_path = ref_row.file_path if ref_row else None
deleted = delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
if not deleted:
session.commit()
return False
if not asset_id:
session.commit()
return True
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - gather ALL file paths (including
# soft-deleted / missing refs) so their on-disk files get cleaned up.
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
# Also include the just-deleted file path
if file_path:
file_paths.append(file_path)
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
# Delete files after commit
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def set_asset_preview(
reference_id: str,
preview_reference_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
get_reference_with_owner_check(session, reference_id, owner_id)
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_reference_id,
)
result = fetch_reference_asset_and_tags(
session, reference_id=reference_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
ref, asset, tags = result
detail = AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
session.commit()
return detail
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
return extract_asset_data(asset)
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
AssetSummaryData(
ref=extract_reference_data(ref),
asset=extract_asset_data(ref.asset),
tags=tag_map.get(ref.id, []),
)
)
return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(
asset_hash: str,
owner_id: str = "",
) -> DownloadResolutionResult | None:
"""Resolve a blake3 hash to an on-disk file path.
Only references visible to *owner_id* are considered (owner-less
references are always visible).
Returns a DownloadResolutionResult with abs_path, content_type, and
download_name, or None if no asset or live path is found.
"""
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash)
if not asset:
return None
refs = list_references_by_asset_id(session, asset_id=asset.id)
visible = [
r for r in refs
if r.owner_id == "" or r.owner_id == owner_id
]
abs_path = select_best_live_path(visible)
if not abs_path:
return None
display_name = os.path.basename(abs_path)
for ref in visible:
if ref.file_path == abs_path and ref.name:
display_name = ref.name
break
ctype = (
asset.mime_type
or mimetypes.guess_type(display_name)[0]
or "application/octet-stream"
)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=display_name,
)
def resolve_asset_for_download(
reference_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetReference {reference_id} not found")
ref, asset = pair
# For references with file_path, use that directly
if ref.file_path and os.path.isfile(ref.file_path):
abs_path = ref.file_path
else:
# For API-created refs without file_path, find a path from other refs
refs = list_references_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(refs)
if not abs_path:
raise FileNotFoundError(
f"No live path for AssetReference {reference_id} "
f"(asset id={asset.id}, name={ref.name})"
)
# Capture ORM attributes before commit (commit expires loaded objects)
ref_name = ref.name
asset_mime = asset.mime_type
update_reference_access_time(session, reference_id=reference_id)
session.commit()
ctype = (
asset_mime
or mimetypes.guess_type(ref_name or abs_path)[0]
or "application/octet-stream"
)
download_name = ref_name or os.path.basename(abs_path)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=download_name,
)

View File

@@ -1,283 +0,0 @@
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
from app.assets.database.queries import (
bulk_insert_assets,
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
restore_references_by_paths,
)
from app.assets.helpers import get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
abs_path: str
size_bytes: int
mtime_ns: int
info_name: str
tags: list[str]
fname: str
metadata: ExtractedMetadata | None
hash: str | None
mime_type: str | None
job_id: str | None
class AssetRow(TypedDict):
"""Row data for inserting an Asset."""
id: str
hash: str | None
size_bytes: int
mime_type: str | None
created_at: datetime
class ReferenceRow(TypedDict):
"""Row data for inserting an AssetReference."""
id: str
asset_id: str
file_path: str
mtime_ns: int
owner_id: str
name: str
preview_id: str | None
user_metadata: dict[str, Any] | None
job_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
class TagRow(TypedDict):
"""Row data for inserting a Tag."""
asset_reference_id: str
tag_name: str
origin: str
added_at: datetime
class MetadataRow(TypedDict):
"""Row data for inserting asset metadata."""
asset_reference_id: str
key: str
ordinal: int
val_str: str | None
val_num: float | None
val_bool: bool | None
val_json: dict[str, Any] | None
@dataclass
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_refs: int
won_paths: int
lost_paths: int
def batch_insert_seed_assets(
session: Session,
specs: list[SeedAssetSpec],
owner_id: str = "",
) -> BulkInsertResult:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim references with ON CONFLICT DO NOTHING on file_path
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert tags and metadata for successfully inserted references
Returns:
BulkInsertResult with inserted_refs, won_paths, lost_paths
"""
if not specs:
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
current_time = get_utc_now()
asset_rows: list[AssetRow] = []
reference_rows: list[ReferenceRow] = []
path_to_asset_id: dict[str, str] = {}
asset_id_to_ref_data: dict[str, dict] = {}
absolute_path_list: list[str] = []
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
asset_id = str(uuid.uuid4())
reference_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path)
path_to_asset_id[absolute_path] = asset_id
mime_type = spec.get("mime_type")
asset_rows.append(
{
"id": asset_id,
"hash": spec.get("hash"),
"size_bytes": spec["size_bytes"],
"mime_type": mime_type,
"created_at": current_time,
}
)
# Build user_metadata from extracted metadata or fallback to filename
extracted_metadata = spec.get("metadata")
if extracted_metadata:
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
elif spec["fname"]:
user_metadata = {"filename": spec["fname"]}
else:
user_metadata = None
reference_rows.append(
{
"id": reference_id,
"asset_id": asset_id,
"file_path": absolute_path,
"mtime_ns": spec["mtime_ns"],
"owner_id": owner_id,
"name": spec["info_name"],
"preview_id": None,
"user_metadata": user_metadata,
"job_id": spec.get("job_id"),
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,
}
)
asset_id_to_ref_data[asset_id] = {
"reference_id": reference_id,
"tags": spec["tags"],
"filename": spec["fname"],
"extracted_metadata": extracted_metadata,
}
bulk_insert_assets(session, asset_rows)
# Filter reference rows to only those whose assets were actually inserted
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
inserted_asset_ids = get_existing_asset_ids(
session, [r["asset_id"] for r in reference_rows]
)
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
bulk_insert_references_ignore_conflicts(session, reference_rows)
restore_references_by_paths(session, absolute_path_list)
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
inserted_paths = {
path
for path in absolute_path_list
if path_to_asset_id[path] in inserted_asset_ids
}
losing_paths = inserted_paths - winning_paths
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
if lost_asset_ids:
delete_assets_by_ids(session, lost_asset_ids)
if not winning_paths:
return BulkInsertResult(
inserted_refs=0,
won_paths=0,
lost_paths=len(losing_paths),
)
# Get reference IDs for winners
winning_ref_ids = [
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
for path in winning_paths
]
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
tag_rows: list[TagRow] = []
metadata_rows: list[MetadataRow] = []
if inserted_ref_ids:
for path in winning_paths:
asset_id = path_to_asset_id[path]
ref_data = asset_id_to_ref_data[asset_id]
ref_id = ref_data["reference_id"]
if ref_id not in inserted_ref_ids:
continue
for tag in ref_data["tags"]:
tag_rows.append(
{
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
}
)
# Use extracted metadata for meta rows if available
extracted_metadata = ref_data.get("extracted_metadata")
if extracted_metadata:
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
elif ref_data["filename"]:
# Fallback: just store filename
metadata_rows.append(
{
"asset_reference_id": ref_id,
"key": "filename",
"ordinal": 0,
"val_str": ref_data["filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(
inserted_refs=len(inserted_ref_ids),
won_paths=len(winning_paths),
lost_paths=len(losing_paths),
)
def cleanup_unreferenced_assets(session: Session) -> int:
"""Hard-delete unhashed assets with no active references.
This is a destructive operation intended for explicit cleanup.
Only deletes assets where hash=None and all references are missing.
Returns:
Number of assets deleted
"""
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
return delete_assets_by_ids(session, unreferenced_ids)

View File

@@ -1,70 +0,0 @@
import os
def get_mtime_ns(stat_result: os.stat_result) -> int:
"""Extract mtime in nanoseconds from a stat result."""
return getattr(
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
)
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
"""Get file size in bytes and mtime in nanoseconds."""
st = os.stat(path, follow_symlinks=follow_symlinks)
return st.st_size, get_mtime_ns(st)
def verify_file_unchanged(
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
"""Check if a file is unchanged based on mtime and size.
Returns True if the file's mtime and size match the database values.
Returns False if mtime_db is None or values don't match.
size_db=None means don't check size; 0 is a valid recorded size.
"""
if mtime_db is None:
return False
actual_mtime_ns = get_mtime_ns(stat_result)
if int(mtime_db) != int(actual_mtime_ns):
return False
if size_db is not None:
return int(stat_result.st_size) == int(size_db)
return True
def is_visible(name: str) -> bool:
"""Return True if a file or directory name is visible (not hidden)."""
return not name.startswith(".")
def list_files_recursively(base_dir: str) -> list[str]:
"""Recursively list all files in a directory, following symlinks."""
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
# Track seen real directory identities to prevent circular symlink loops
seen_dirs: set[tuple[int, int]] = set()
for dirpath, subdirs, filenames in os.walk(
base_abs, topdown=True, followlinks=True
):
try:
st = os.stat(dirpath)
dir_id = (st.st_dev, st.st_ino)
except OSError:
subdirs.clear()
continue
if dir_id in seen_dirs:
subdirs.clear()
continue
seen_dirs.add(dir_id)
subdirs[:] = [d for d in subdirs if is_visible(d)]
for name in filenames:
if not is_visible(name):
continue
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out

View File

@@ -1,99 +0,0 @@
import io
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterator
import logging
try:
from blake3 import blake3
except ModuleNotFoundError:
logging.warning("WARNING: blake3 package not installed")
DEFAULT_CHUNK = 8 * 1024 * 1024
InterruptCheck = Callable[[], bool]
@dataclass
class HashCheckpoint:
"""Saved state for resuming an interrupted hash computation."""
bytes_processed: int
hasher: Any # blake3 hasher instance
mtime_ns: int = 0
file_size: int = 0
@contextmanager
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
"""Yield (file_object, is_path) with appropriate setup/teardown."""
if hasattr(fp, "read"):
seekable = getattr(fp, "seekable", lambda: False)()
orig_pos = None
if seekable:
try:
orig_pos = fp.tell()
if orig_pos != 0:
fp.seek(0)
except io.UnsupportedOperation:
orig_pos = None
try:
yield fp, False
finally:
if orig_pos is not None:
fp.seek(orig_pos)
else:
with open(os.fspath(fp), "rb") as f:
yield f, True
def compute_blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
interrupt_check: InterruptCheck | None = None,
checkpoint: HashCheckpoint | None = None,
) -> tuple[str | None, HashCheckpoint | None]:
"""Compute BLAKE3 hash of a file, with optional checkpoint support.
Args:
fp: File path or file-like object
chunk_size: Size of chunks to read at a time
interrupt_check: Optional callable that returns True if the operation
should be interrupted (e.g. paused or cancelled). Must be
non-blocking so file handles are released immediately. Checked
between chunk reads.
checkpoint: Optional checkpoint to resume from (file paths only)
Returns:
Tuple of (hex_digest, None) on completion, or
(None, checkpoint) on interruption (file paths only), or
(None, None) on interruption of a file object
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
with _open_for_hashing(fp) as (f, is_path):
if checkpoint is not None and is_path:
f.seek(checkpoint.bytes_processed)
h = checkpoint.hasher
bytes_processed = checkpoint.bytes_processed
else:
h = blake3()
bytes_processed = 0
while True:
if interrupt_check is not None and interrupt_check():
if is_path:
return None, HashCheckpoint(
bytes_processed=bytes_processed,
hasher=h,
)
return None, None
chunk = f.read(chunk_size)
if not chunk:
break
h.update(chunk)
bytes_processed += len(chunk)
return h.hexdigest(), None

View File

@@ -1,563 +0,0 @@
import contextlib
import logging
import mimetypes
import os
from typing import Any, Sequence
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.queries import (
add_tags_to_reference,
count_active_siblings,
create_stub_asset,
ensure_tags_exist,
fetch_reference_and_asset,
get_asset_by_hash,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
validate_path_within_base,
)
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UploadResult,
UserMetadata,
extract_asset_data,
extract_reference_data,
)
from app.database.db import create_session
def _ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
user_metadata = user_metadata or {}
asset_created = False
asset_updated = False
ref_created = False
ref_updated = False
reference_id: str | None = None
with create_session() as session:
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
ref_created, ref_updated = upsert_reference(
session,
asset_id=asset.id,
file_path=locator,
name=info_name or os.path.basename(locator),
mtime_ns=mtime_ns,
owner_id=owner_id,
)
# Get the reference we just created/updated
ref = get_reference_by_file_path(session, locator)
if ref:
reference_id = ref.id
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
norm = normalize_tags(list(tags))
if norm:
if require_existing_tags:
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
_update_metadata_with_filename(
session,
reference_id=reference_id,
file_path=ref.file_path,
current_metadata=ref.user_metadata,
user_metadata=user_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
session.commit()
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
ref_created=ref_created,
ref_updated=ref_updated,
reference_id=reference_id,
)
def register_output_files(
file_paths: Sequence[str],
user_metadata: UserMetadata = None,
job_id: str | None = None,
) -> int:
"""Register a batch of output file paths as assets.
Returns the number of files successfully registered.
"""
registered = 0
for abs_path in file_paths:
if not os.path.isfile(abs_path):
continue
try:
if ingest_existing_file(
abs_path, user_metadata=user_metadata, job_id=job_id
):
registered += 1
except Exception:
logging.exception("Failed to register output: %s", abs_path)
return registered
def ingest_existing_file(
abs_path: str,
user_metadata: UserMetadata = None,
extra_tags: Sequence[str] = (),
owner_id: str = "",
job_id: str | None = None,
) -> bool:
"""Register an existing on-disk file as an asset stub.
If a reference already exists for this path, updates mtime_ns, job_id,
size_bytes, and resets enrichment so the enricher will re-hash it.
For brand-new paths, inserts a stub record (hash=NULL) for immediate
UX visibility.
Returns True if a row was inserted or updated, False otherwise.
"""
locator = os.path.abspath(abs_path)
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
with create_session() as session:
existing_ref = get_reference_by_file_path(session, locator)
if existing_ref is not None:
now = get_utc_now()
existing_ref.mtime_ns = mtime_ns
existing_ref.job_id = job_id
existing_ref.is_missing = False
existing_ref.deleted_at = None
existing_ref.updated_at = now
existing_ref.enrichment_level = 0
asset = existing_ref.asset
if asset:
# If other refs share this asset, detach to a new stub
# instead of mutating the shared row.
siblings = count_active_siblings(session, asset.id, existing_ref.id)
if siblings > 0:
new_asset = create_stub_asset(
session,
size_bytes=size_bytes,
mime_type=mime_type or asset.mime_type,
)
existing_ref.asset_id = new_asset.id
else:
asset.hash = None
asset.size_bytes = size_bytes
if mime_type:
asset.mime_type = mime_type
session.commit()
return True
spec = {
"abs_path": abs_path,
"size_bytes": size_bytes,
"mtime_ns": mtime_ns,
"info_name": name,
"tags": tags,
"fname": os.path.basename(abs_path),
"metadata": None,
"hash": None,
"mime_type": mime_type,
"job_id": job_id,
}
if tags:
ensure_tags_exist(session, tags)
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
session.commit()
return result.won_paths > 0
def _register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
if mime_type and not asset.mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
ref, ref_created = get_or_create_reference(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if not ref_created:
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
)
session.commit()
return result
new_meta = dict(user_metadata)
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_reference_metadata(
session,
reference_id=ref.id,
user_metadata=new_meta,
)
if tags is not None:
set_reference_tags(
session,
reference_id=ref.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_reference_tags(session, reference_id=ref.id)
session.refresh(ref)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
)
session.commit()
return result
def _update_metadata_with_filename(
session: Session,
reference_id: str,
file_path: str | None,
current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None:
computed_filename = compute_relative_filename(file_path) if file_path else None
current_meta = current_metadata or {}
new_meta = dict(current_meta)
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_reference_metadata(
session,
reference_id=reference_id,
user_metadata=new_meta,
)
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback
class HashMismatchError(Exception):
pass
class DependencyMissingError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
def upload_from_temp_path(
temp_path: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_hash and asset_hash != expected_hash.strip().lower():
raise HashMismatchError("Uploaded file hash does not match provided hash.")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = _register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
)
return UploadResult(
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
)
if not tags:
raise ValueError("tags are required for new asset uploads")
base_dir, subdirs = resolve_destination_from_tags(tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = _ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=preview_id,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def register_file_in_place(
abs_path: str,
name: str,
tags: list[str],
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash file: {e}")
asset_hash = "blake3:" + digest
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
content_type = mime_type or (
mimetypes.guess_type(abs_path, strict=False)[0]
or "application/octet-stream"
)
ingest_result = _ingest_file_from_path(
abs_path=abs_path,
asset_hash=asset_hash,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name, fallback=digest),
owner_id=owner_id,
tags=merged_tags,
tag_origin="upload",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult | None:
canonical = hash_str.strip().lower()
try:
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
)
except ValueError:
logging.warning("create_from_hash: no asset found for hash %s", canonical)
return None
return UploadResult(
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
)

View File

@@ -1,327 +0,0 @@
"""Metadata extraction for asset scanning.
Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
from __future__ import annotations
import json
import logging
import mimetypes
import os
import struct
from dataclasses import dataclass
from typing import Any
from utils.mime_types import init_mime_types
init_mime_types()
# Supported safetensors extensions
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
# Maximum safetensors header size to read (8MB)
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
@dataclass
class ExtractedMetadata:
"""Metadata extracted from a file during scanning."""
# Tier 1: Filesystem (always available)
filename: str = ""
file_path: str = "" # Full absolute path to the file
content_length: int = 0
content_type: str | None = None
format: str = "" # file extension without dot
# Tier 2: Safetensors header (if available)
base_model: str | None = None
trained_words: list[str] | None = None
air: str | None = None # CivitAI AIR identifier
has_preview_images: bool = False
# Source provenance (populated if embedded in safetensors)
source_url: str | None = None
source_arn: str | None = None
repo_url: str | None = None
preview_url: str | None = None
source_hash: str | None = None
# HuggingFace specific
repo_id: str | None = None
revision: str | None = None
filepath: str | None = None
resolve_url: str | None = None
def to_user_metadata(self) -> dict[str, Any]:
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
data: dict[str, Any] = {
"filename": self.filename,
"content_length": self.content_length,
"format": self.format,
}
if self.file_path:
data["file_path"] = self.file_path
if self.content_type:
data["content_type"] = self.content_type
# Tier 2 fields
if self.base_model:
data["base_model"] = self.base_model
if self.trained_words:
data["trained_words"] = self.trained_words
if self.air:
data["air"] = self.air
if self.has_preview_images:
data["has_preview_images"] = True
# Source provenance
if self.source_url:
data["source_url"] = self.source_url
if self.source_arn:
data["source_arn"] = self.source_arn
if self.repo_url:
data["repo_url"] = self.repo_url
if self.preview_url:
data["preview_url"] = self.preview_url
if self.source_hash:
data["source_hash"] = self.source_hash
# HuggingFace
if self.repo_id:
data["repo_id"] = self.repo_id
if self.revision:
data["revision"] = self.revision
if self.filepath:
data["filepath"] = self.filepath
if self.resolve_url:
data["resolve_url"] = self.resolve_url
return data
def to_meta_rows(self, reference_id: str) -> list[dict]:
"""Convert to asset_reference_meta rows for typed/indexed querying."""
rows: list[dict] = []
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
if val:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": ordinal,
"val_str": val[:2048] if len(val) > 2048 else val,
"val_num": None,
"val_bool": None,
"val_json": None,
})
def add_num(key: str, val: int | float | None) -> None:
if val is not None:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": val,
"val_bool": None,
"val_json": None,
})
def add_bool(key: str, val: bool | None) -> None:
if val is not None:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": None,
"val_bool": val,
"val_json": None,
})
# Tier 1
add_str("filename", self.filename)
add_num("content_length", self.content_length)
add_str("content_type", self.content_type)
add_str("format", self.format)
# Tier 2
add_str("base_model", self.base_model)
add_str("air", self.air)
has_previews = self.has_preview_images if self.has_preview_images else None
add_bool("has_preview_images", has_previews)
# trained_words as multiple rows with ordinals
if self.trained_words:
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
add_str("trained_words", word, ordinal=i)
# Source provenance
add_str("source_url", self.source_url)
add_str("source_arn", self.source_arn)
add_str("repo_url", self.repo_url)
add_str("preview_url", self.preview_url)
add_str("source_hash", self.source_hash)
# HuggingFace
add_str("repo_id", self.repo_id)
add_str("revision", self.revision)
add_str("filepath", self.filepath)
add_str("resolve_url", self.resolve_url)
return rows
def _read_safetensors_header(
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
) -> dict[str, Any] | None:
"""Read only the JSON header from a safetensors file.
This is very fast - reads 8 bytes for header length, then the JSON header.
No tensor data is loaded.
Args:
path: Absolute path to safetensors file
max_size: Maximum header size to read (default 8MB)
Returns:
Parsed header dict or None if failed
"""
try:
with open(path, "rb") as f:
header_bytes = f.read(8)
if len(header_bytes) < 8:
return None
length_of_header = struct.unpack("<Q", header_bytes)[0]
if length_of_header > max_size:
return None
header_data = f.read(length_of_header)
if len(header_data) < length_of_header:
return None
return json.loads(header_data.decode("utf-8"))
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
return None
def _extract_safetensors_metadata(
header: dict[str, Any], meta: ExtractedMetadata
) -> None:
"""Extract metadata from safetensors header __metadata__ section.
Modifies meta in-place.
"""
st_meta = header.get("__metadata__", {})
if not isinstance(st_meta, dict):
return
# Common model metadata
meta.base_model = (
st_meta.get("ss_base_model_version")
or st_meta.get("modelspec.base_model")
or st_meta.get("base_model")
)
# Trained words / trigger words
trained_words = st_meta.get("ss_tag_frequency")
if trained_words and isinstance(trained_words, str):
try:
tag_freq = json.loads(trained_words)
# Extract unique tags from all datasets
all_tags: set[str] = set()
for dataset_tags in tag_freq.values():
if isinstance(dataset_tags, dict):
all_tags.update(dataset_tags.keys())
if all_tags:
meta.trained_words = sorted(all_tags)[:100]
except json.JSONDecodeError:
pass
# Direct trained_words field (some formats)
if not meta.trained_words:
tw = st_meta.get("trained_words")
if isinstance(tw, str):
try:
parsed = json.loads(tw)
if isinstance(parsed, list):
meta.trained_words = [str(x) for x in parsed]
else:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
except json.JSONDecodeError:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
elif isinstance(tw, list):
meta.trained_words = [str(x) for x in tw]
# CivitAI AIR
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
# Preview images (ssmd_cover_images)
cover_images = st_meta.get("ssmd_cover_images")
if cover_images:
meta.has_preview_images = True
# Source provenance fields
meta.source_url = st_meta.get("source_url")
meta.source_arn = st_meta.get("source_arn")
meta.repo_url = st_meta.get("repo_url")
meta.preview_url = st_meta.get("preview_url")
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
# HuggingFace fields
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
def extract_file_metadata(
abs_path: str,
stat_result: os.stat_result | None = None,
relative_filename: str | None = None,
) -> ExtractedMetadata:
"""Extract metadata from a file using tier 1 and tier 2 methods.
Tier 1: Filesystem metadata from path and stat
Tier 2: Safetensors header parsing if applicable
Args:
abs_path: Absolute path to the file
stat_result: Optional pre-fetched stat result (saves a syscall)
relative_filename: Optional relative filename to use instead of basename
(e.g., "flux/123/model.safetensors" for model paths)
Returns:
ExtractedMetadata with all available fields populated
"""
meta = ExtractedMetadata()
# Tier 1: Filesystem metadata
meta.filename = relative_filename or os.path.basename(abs_path)
meta.file_path = abs_path
_, ext = os.path.splitext(abs_path)
meta.format = ext.lstrip(".").lower() if ext else ""
mime_type, _ = mimetypes.guess_type(abs_path)
meta.content_type = mime_type
# Size from stat
if stat_result is None:
try:
stat_result = os.stat(abs_path, follow_symlinks=True)
except OSError:
pass
if stat_result:
meta.content_length = stat_result.st_size
# Tier 2: Safetensors header (if applicable and enabled)
if ext.lower() in SAFETENSORS_EXTENSIONS:
header = _read_safetensors_header(abs_path)
if header:
try:
_extract_safetensors_metadata(header, meta)
except Exception as e:
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
return meta

View File

@@ -1,173 +0,0 @@
import os
from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build list of (folder_name, base_paths[]) for all model locations.
Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir,
but excludes non-model entries like custom_nodes.
"""
targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items():
if name in _NON_MODEL_FOLDER_NAMES:
continue
paths, _exts = values[0], values[1]
if paths:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
cand_abs = Path(os.path.abspath(candidate))
base_abs = Path(os.path.abspath(base))
if not cand_abs.is_relative_to(base_abs):
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "temp", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'temp': under folder_paths.get_temp_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
(root_category, relative_path_inside_that_root)
Raises:
ValueError: path does not belong to any known root.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return "input", _compute_relative(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) temp
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return "temp", _compute_relative(fp_abs, temp_base)
# 4) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] + parent folder names in order
Raises:
ValueError: path does not belong to any known root.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@@ -1,113 +0,0 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetReference
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
hash: str | None
size_bytes: int | None
mime_type: str | None
@dataclass(frozen=True)
class ReferenceData:
"""Data transfer object for AssetReference."""
id: str
name: str
file_path: str | None
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
system_metadata: dict[str, Any] | None = None
job_id: str | None = None
last_access_time: datetime | None = None
@dataclass(frozen=True)
class AssetDetailResult:
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
ref: ReferenceData
asset: AssetData
tags: list[str]
created: bool
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
ref_created: bool
ref_updated: bool
reference_id: str | None
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@dataclass(frozen=True)
class AssetSummaryData:
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
@dataclass(frozen=True)
class DownloadResolutionResult:
abs_path: str
content_type: str
download_name: str
@dataclass(frozen=True)
class UploadResult:
ref: ReferenceData
asset: AssetData
tags: list[str]
created_new: bool
def extract_reference_data(ref: AssetReference) -> ReferenceData:
return ReferenceData(
id=ref.id,
name=ref.name,
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
job_id=ref.job_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,
)
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(
hash=asset.hash,
size_bytes=asset.size_bytes,
mime_type=asset.mime_type,
)

View File

@@ -1,98 +0,0 @@
from typing import Sequence
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
add_tags_to_reference,
get_reference_with_owner_check,
list_tags_with_usage,
remove_tags_from_reference,
)
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
def apply_tags(
reference_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
result = add_tags_to_reference(
session,
reference_id=reference_id,
tags=tags,
origin=origin,
create_if_missing=True,
reference_row=ref_row,
)
session.commit()
return result
def remove_tags(
reference_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
get_reference_with_owner_check(session, reference_id, owner_id)
result = remove_tags_from_reference(
session,
reference_id=reference_id,
tags=tags,
)
session.commit()
return result
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@@ -3,7 +3,6 @@ import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from filelock import FileLock, Timeout
from comfy.cli_args import args
_DB_AVAILABLE = False
@@ -15,12 +14,8 @@ try:
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, event
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database.models import Base
import app.assets.database.models # noqa: F401 — register models with Base.metadata
_DB_AVAILABLE = True
except ImportError as e:
@@ -70,69 +65,9 @@ def get_db_path():
raise ValueError(f"Unsupported database URL '{url}'.")
_db_lock = None
def _acquire_file_lock(db_path):
"""Acquire an OS-level file lock to prevent multi-process access.
Uses filelock for cross-platform support (macOS, Linux, Windows).
The OS automatically releases the lock when the process exits, even on crashes.
"""
global _db_lock
lock_path = db_path + ".lock"
_db_lock = FileLock(lock_path)
try:
_db_lock.acquire(timeout=0)
except Timeout:
raise RuntimeError(
f"Could not acquire lock on database '{db_path}'. "
"Another ComfyUI process may already be using it. "
"Use --database-url to specify a separate database file."
)
def _is_memory_db(db_url):
"""Check if the database URL refers to an in-memory SQLite database."""
return db_url in ("sqlite:///:memory:", "sqlite://")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
if _is_memory_db(db_url):
_init_memory_db(db_url)
else:
_init_file_db(db_url)
def _init_memory_db(db_url):
"""Initialize an in-memory SQLite database using metadata.create_all.
Alembic migrations don't work with in-memory SQLite because each
connection gets its own separate database — tables created by Alembic's
internal connection are lost immediately.
"""
engine = create_engine(
db_url,
poolclass=StaticPool,
connect_args={"check_same_thread": False},
)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine)
global Session
Session = sessionmaker(bind=engine)
def _init_file_db(db_url):
"""Initialize a file-backed SQLite database using Alembic migrations."""
db_path = get_db_path()
db_exists = os.path.exists(db_path)
@@ -140,14 +75,6 @@ def _init_file_db(db_url):
# Check if we need to upgrade
engine = create_engine(db_url)
# Enable foreign key enforcement for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
conn = engine.connect()
context = MigrationContext.configure(conn)
@@ -177,12 +104,6 @@ def _init_file_db(db_url):
logging.exception("Error upgrading database: ")
raise e
# Acquire an OS-level file lock after migrations are complete.
# Alembic uses its own connection, so we must wait until it's done
# before locking — otherwise our own lock blocks the migration.
conn.close()
_acquire_file_lock(db_path)
global Session
Session = sessionmaker(bind=engine)

View File

@@ -1,18 +1,9 @@
from typing import Any
from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase):
metadata = MetaData(naming_convention=NAMING_CONVENTION)
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()

View File

@@ -6,7 +6,6 @@ import uuid
import glob
import shutil
import logging
import tempfile
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@@ -378,15 +377,8 @@ class UserManager():
try:
body = await request.read()
dir_name = os.path.dirname(path)
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(

View File

@@ -1,90 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform float u_float0;
uniform float u_float1;
uniform float u_float2;
uniform float u_float3;
uniform float u_float4;
uniform float u_float5;
uniform float u_float6;
uniform float u_float7;
uniform float u_float8;
uniform bool u_bool0;
in vec2 v_texCoord;
out vec4 fragColor;
vec3 rgb2hsl(vec3 c) {
float maxC = max(c.r, max(c.g, c.b));
float minC = min(c.r, min(c.g, c.b));
float l = (maxC + minC) * 0.5;
if (maxC == minC) return vec3(0.0, 0.0, l);
float d = maxC - minC;
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
float h;
if (maxC == c.r) {
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
} else if (maxC == c.g) {
h = (c.b - c.r) / d + 2.0;
} else {
h = (c.r - c.g) / d + 4.0;
}
h /= 6.0;
return vec3(h, s, l);
}
float hue2rgb(float p, float q, float t) {
if (t < 0.0) t += 1.0;
if (t > 1.0) t -= 1.0;
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
if (t < 1.0 / 2.0) return q;
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
return p;
}
vec3 hsl2rgb(vec3 hsl) {
float h = hsl.x, s = hsl.y, l = hsl.z;
if (s == 0.0) return vec3(l);
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
float p = 2.0 * l - q;
return vec3(
hue2rgb(p, q, h + 1.0 / 3.0),
hue2rgb(p, q, h),
hue2rgb(p, q, h - 1.0 / 3.0)
);
}
void main() {
vec4 tex = texture(u_image0, v_texCoord);
vec3 color = tex.rgb;
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
float maxC = max(color.r, max(color.g, color.b));
float minC = min(color.r, min(color.g, color.b));
float lightness = (maxC + minC) * 0.5;
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
const float a = 0.25;
const float b = 0.333;
const float scale = 0.7;
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
color += sw * shadows + mw * midtones + hw * highlights;
if (u_bool0) {
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
hsl.z = lightness;
color = hsl2rgb(hsl);
}
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
}

View File

@@ -1,49 +0,0 @@
#version 300 es
precision highp float;
uniform sampler2D u_image0;
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
uniform sampler2D u_curve1; // Red channel curve
uniform sampler2D u_curve2; // Green channel curve
uniform sampler2D u_curve3; // Blue channel curve
in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0;
// GIMP-compatible curve lookup with manual linear interpolation.
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
// index = value * (n_samples - 1)
// f = fract(index)
// result = (1-f) * samples[floor] + f * samples[ceil]
//
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
float applyCurve(sampler2D curve, float value) {
value = clamp(value, 0.0, 1.0);
float pos = value * 255.0;
int lo = int(floor(pos));
int hi = min(lo + 1, 255);
float f = pos - float(lo);
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
return a + f * (b - a);
}
void main() {
vec4 color = texture(u_image0, v_texCoord);
// GIMP order: per-channel curves first, then RGB master curve.
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
// dest = colors_curve( channel_curve( src ) )
float tmp_r = applyCurve(u_curve1, color.r);
float tmp_g = applyCurve(u_curve2, color.g);
float tmp_b = applyCurve(u_curve3, color.b);
color.r = applyCurve(u_curve0, tmp_r);
color.g = applyCurve(u_curve0, tmp_g);
color.b = applyCurve(u_curve0, tmp_b);
fragColor0 = vec4(color.rgb, color.a);
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -27,7 +27,6 @@ class AudioEncoderModel():
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
comfy.model_management.archive_model_dtypes(self.model)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

View File

@@ -83,8 +83,6 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
@@ -110,13 +108,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -151,7 +147,6 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@@ -237,7 +232,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()
@@ -265,6 +260,4 @@ else:
args.fast = set(args.fast)
def enables_dynamic_vram():
if args.enable_dynamic_vram:
return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[dict]]
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
gradient_stops: NotRequired[list[list[float]]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
class HiddenInputTypeDict(TypedDict):

View File

@@ -93,50 +93,6 @@ class IndexListCallbacks:
return {}
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
return None
cond_tensor = cond_value.cond
if temporal_dim >= cond_tensor.ndim:
return None
cond_size = cond_tensor.size(temporal_dim)
if temporal_scale == 1:
expected_size = x_in.size(window.dim) - temporal_offset
if cond_size != expected_size:
return None
if temporal_offset == 0 and temporal_scale == 1:
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
return cond_value._copy_with(sliced)
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
if not indices:
return None
if temporal_scale > 1:
scaled = []
for i in indices:
for k in range(temporal_scale):
si = i * temporal_scale + k
if si < cond_size:
scaled.append(si)
indices = scaled
if not indices:
return None
idx = tuple([slice(None)] * temporal_dim + [indices])
sliced = cond_tensor[idx].to(device)
return cond_value._copy_with(sliced)
@dataclass
class ContextSchedule:
name: str
@@ -221,17 +177,10 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item[cond_key] = result
handled = True
break
if not handled and self._model is not None:
result = self._model.resize_cond_for_context_window(
cond_key, cond_value, window, x_in, device,
retain_index_list=self.cond_retain_index_list)
if result is not None:
new_cond_item[cond_key] = result
handled = True
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
@@ -275,7 +224,6 @@ class IndexListContextHandler(ContextHandlerABC):
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))

View File

@@ -209,39 +209,3 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@@ -136,7 +136,16 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

View File

@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
return tensor
@@ -223,19 +223,12 @@ class DoubleStreamBlock(nn.Module):
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
@@ -328,12 +321,6 @@ class SingleStreamBlock(nn.Module):
del qkv
q, k = self.norm(q, k, v)
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

View File

@@ -31,8 +31,6 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

View File

@@ -44,22 +44,6 @@ class FluxParams:
txt_norm: bool = False
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -154,7 +138,6 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control = None,
timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
@@ -181,9 +164,13 @@ class Flux(nn.Module):
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]
@@ -195,24 +182,6 @@ class Flux(nn.Module):
else:
pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
@@ -226,8 +195,7 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
**extra_kwargs)
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
@@ -245,8 +213,7 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
attn_mask=attn_mask,
transformer_options=transformer_options,
**extra_kwargs)
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -263,12 +230,6 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@@ -281,8 +242,7 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"),
**extra_kwargs)
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
@@ -293,7 +253,7 @@ class Flux(nn.Module):
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -304,11 +264,7 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -356,16 +312,13 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
if ref_latents_method in ("index", "index_timestep_zero"):
if ref_latents_method == "index":
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
@@ -386,16 +339,9 @@ class Flux(nn.Module):
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@@ -403,6 +349,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@@ -343,7 +343,6 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
low_precision_attention=False,
)
out = self.out_proj(x)
@@ -413,7 +412,6 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
low_precision_attention=False,
)
x = self.out_proj(x)

View File

@@ -681,33 +681,6 @@ class LTXAVModel(LTXVModel):
additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax)
# Inject reference audio for ID-LoRA in-context conditioning
ref_audio = kwargs.get("ref_audio", None)
ref_audio_seq_len = 0
if ref_audio is not None:
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
if ref_tokens.shape[0] < ax.shape[0]:
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
ref_audio_seq_len = ref_tokens.shape[1]
B = ax.shape[0]
# Compute negative temporal positions matching ID-LoRA convention:
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
p = self.a_patchifier
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
time_offset = ref_end[-1].item() + tpl
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
additional_args["target_audio_seq_len"] = ax.shape[1]
ax = torch.cat([ref_tokens, ax], dim=1)
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
@@ -748,14 +721,6 @@ class LTXAVModel(LTXVModel):
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0 and a_timestep is not None:
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
target_len = kwargs.get("target_audio_seq_len")
if a_timestep.dim() <= 1:
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
@@ -990,13 +955,6 @@ class LTXAVModel(LTXVModel):
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
# Trim reference audio tokens before unpatchification
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0:
ax = ax[:, ref_audio_seq_len:]
if a_embedded_timestep.shape[1] > 1:
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()

View File

@@ -23,11 +23,6 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels
if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
@@ -63,25 +58,16 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ]
if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end
if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@@ -11,7 +11,6 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
@@ -233,7 +232,10 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
@@ -244,14 +246,10 @@ class Encoder(nn.Module):
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
if sample is None or sample.shape[2] == 0:
return None
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if sample is None or sample.shape[2] == 0:
return None
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
@@ -283,35 +281,9 @@ class Encoder(nn.Module):
return sample
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
outputs = []
samples = [sample[:, :, :1, :, :]]
if sample.shape[2] > 1:
chunk_t = max(2, max_chunk_size // frame_size)
if chunk_t < 4:
chunk_t = 2
elif chunk_t < 8:
chunk_t = 4
else:
chunk_t = (chunk_t // 8) * 8
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
for chunk_idx, chunk in enumerate(samples):
if chunk_idx == len(samples) - 1:
mark_conv3d_ended(self)
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
output = self._forward_chunk(chunk)
if output is not None:
outputs.append(output)
return torch_cat_if_needed(outputs, dim=2)
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
@@ -324,23 +296,7 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
total_memory = comfy.model_management.get_total_memory(dev=device)
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
return MIN_CHUNK_SIZE
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
return MAX_CHUNK_SIZE
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
)
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module):
r"""
@@ -500,17 +456,6 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
self.timestep_conditioning = timestep_conditioning
if timestep_conditioning:
@@ -532,62 +477,11 @@ class Decoder(nn.Module):
)
def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
return
up_block = self.up_blocks[idx]
if ended:
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward_orig(
self,
sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0]
@@ -602,7 +496,6 @@ class Decoder(nn.Module):
)
timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning:
assert (
timestep is not None
@@ -630,18 +523,48 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
output = []
max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
return output_buffer
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs):
try:
@@ -765,25 +688,12 @@ class SpaceToDepthDownsample(nn.Module):
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True):
tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None
if self.stride[0] == 2 and pad_first:
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
pad_first = False
if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None
# skip connection
x_in = rearrange(
@@ -798,26 +708,15 @@ class SpaceToDepthDownsample(nn.Module):
# conv
x = self.conv(x, causal=causal)
if self.stride[0] == 2 and x.shape[2] == 1:
if cached_x is not None:
x = torch_cat_if_needed([cached_x, x], dim=2)
cached_x = None
else:
cached_x = x
x = None
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
cached = add_exchange_cache(x, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
x = x + x_in
return x
@@ -1150,8 +1049,6 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
comfy_has_chunked_io = True
def __init__(self, version=0, config=None):
super().__init__()
@@ -1294,15 +1191,14 @@ class VideoVAE(nn.Module):
}
return config
def encode(self, x, device=None):
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode(self, x, output_buffer=None):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)

View File

@@ -2,7 +2,6 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import comfy.model_management
import numpy as np
import math
@@ -82,7 +81,7 @@ class LowPassFilter1d(nn.Module):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
class UpSample1d(nn.Module):
@@ -126,7 +125,7 @@ class UpSample1d(nn.Module):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
@@ -191,7 +190,7 @@ class Snake(nn.Module):
self.eps = 1e-9
def forward(self, x):
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
a = self.alpha.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
@@ -218,8 +217,8 @@ class SnakeBeta(nn.Module):
self.eps = 1e-9
def forward(self, x):
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
a = self.alpha.unsqueeze(0).unsqueeze(-1)
b = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
a = torch.exp(a)
b = torch.exp(b)
@@ -597,7 +596,7 @@ class _STFTFn(nn.Module):
y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)
@@ -648,7 +647,7 @@ class MelSTFT(nn.Module):
"""
magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy

View File

@@ -372,8 +372,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except Exception as e:
model_management.raise_non_oom(e)
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:

View File

@@ -258,8 +258,7 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except Exception as e:
model_management.raise_non_oom(e)
except model_management.OOM_EXCEPTION as e:
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
@@ -315,8 +314,7 @@ def pytorch_attention(q, k, v):
try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except Exception as e:
model_management.raise_non_oom(e)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:

View File

@@ -169,8 +169,7 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except Exception as e:
model_management.raise_non_oom(e)
except model_management.OOM_EXCEPTION:
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)

View File

@@ -149,9 +149,6 @@ class Attention(nn.Module):
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
@@ -170,22 +167,15 @@ class Attention(nn.Module):
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
@@ -454,7 +444,6 @@ class QwenImageTransformer2DModel(nn.Module):
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
@@ -485,16 +474,16 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -506,18 +495,6 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):

View File

@@ -1,725 +0,0 @@
from collections import OrderedDict
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
COCO_CLASSES = [
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
'remote','keyboard','cell phone','microwave','oven','toaster','sink',
'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
]
# ---------------------------------------------------------------------------
# HGNetv2 backbone
# ---------------------------------------------------------------------------
class ConvBNAct(nn.Module):
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
super().__init__()
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
self.act = nn.ReLU() if use_act else nn.Identity()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class LightConvBNAct(nn.Module):
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
super().__init__()
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.conv2(self.conv1(x))
class _StemBlock(nn.Module):
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
super().__init__()
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
# padding is applied manually in forward (matching PaddlePaddle original)
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
def forward(self, x):
x = self.stem1(x)
x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a
x2 = self.stem2a(x)
x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b
x2 = self.stem2b(x2)
x1 = self.pool(x)
return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
class _HG_Block(nn.Module):
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
super().__init__()
self.residual = residual
if light:
self.layers = nn.ModuleList(
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
else:
self.layers = nn.ModuleList(
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
total = ic + layer_num * mc
self.aggregation = nn.Sequential(
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
def forward(self, x):
identity = x
outs = [x]
for layer in self.layers:
x = layer(x)
outs.append(x)
x = self.aggregation(torch.cat(outs, 1))
return x + identity if self.residual else x
class _HG_Stage(nn.Module):
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
super().__init__()
if downsample:
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
else:
self.downsample = nn.Identity()
self.blocks = nn.Sequential(*[
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
for i in range(num_blocks)
])
def forward(self, x):
return self.blocks(self.downsample(x))
class HGNetv2(nn.Module):
# B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
_STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
[128, 128, 512, 2, True, False, 3, 6],
[512, 256, 1024, 5, True, True, 5, 6],
[1024,512, 2048, 2, True, True, 5, 6]]
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
super().__init__()
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
self.return_idx = list(return_idx)
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.stem(x)
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.return_idx:
outs.append(x)
return outs
# ---------------------------------------------------------------------------
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
# ---------------------------------------------------------------------------
class ConvNormLayer(nn.Module):
"""Conv→act (expects pre-fused BN weights)."""
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
super().__init__()
p = (k - 1) // 2 if padding is None else padding
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
def forward(self, x):
return self.act(self.conv(x))
class VGGBlock(nn.Module):
"""Rep-VGG block (expects pre-fused weights)."""
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
super().__init__()
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.conv(x))
class CSPLayer(nn.Module):
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
super().__init__()
h = int(oc * expansion)
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
def forward(self, x):
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
class RepNCSPELAN4(nn.Module):
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
super().__init__()
self.c = c3 // 2
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
def forward(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
class SCDown(nn.Module):
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
super().__init__()
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.cv2(self.cv1(x))
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, query, key, value, attn_mask=None):
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
return self.out_proj(out)
class _TransformerEncoderLayer(nn.Module):
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.activation = nn.GELU()
def forward(self, src, src_mask=None, pos_embed=None):
q = k = src if pos_embed is None else src + pos_embed
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = self.norm1(src + src2)
src2 = self.linear2(self.activation(self.linear1(src)))
return self.norm2(src + src2)
class _TransformerEncoder(nn.Module):
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
def forward(self, src, src_mask=None, pos_embed=None):
for layer in self.layers:
src = layer(src, src_mask=src_mask, pos_embed=pos_embed)
return src
class HybridEncoder(nn.Module):
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
self.in_channels = list(in_channels)
self.feat_strides = list(feat_strides)
self.hidden_dim = hidden_dim
self.use_encoder_idx = list(use_encoder_idx)
self.pe_temperature = pe_temperature
self.eval_spatial_size = eval_spatial_size
self.out_channels = [hidden_dim] * len(in_channels)
self.out_strides = list(feat_strides)
# channel projection (expects pre-fused weights)
self.input_proj = nn.ModuleList([
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
for ch in in_channels
])
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
self.encoder = nn.ModuleList([
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(len(use_encoder_idx))
])
nb = round(3 * depth_mult)
exp = expansion
# top-down FPN (dfine: lateral conv has no act)
self.lateral_convs = nn.ModuleList(
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
self.fpn_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
self.downsample_convs = nn.ModuleList(
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
for _ in range(len(in_channels) - 1)])
self.pan_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# cache positional embeddings for fixed spatial size
if eval_spatial_size:
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
pe = self._build_pe(eval_spatial_size[1] // stride,
eval_spatial_size[0] // stride,
hidden_dim, pe_temperature)
setattr(self, f'pos_embed{idx}', pe)
@staticmethod
def _build_pe(w, h, dim=256, temp=10000.):
assert dim % 4 == 0
gw = torch.arange(w, dtype=torch.float32)
gh = torch.arange(h, dtype=torch.float32)
gw, gh = torch.meshgrid(gw, gh, indexing='ij')
pdim = dim // 4
omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
ow = gw.flatten()[:, None] @ omega[None]
oh = gh.flatten()[:, None] @ omega[None]
return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]
def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
for i, enc_idx in enumerate(self.use_encoder_idx):
h, w = proj[enc_idx].shape[2:]
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
for layer in self.encoder[i].layers:
src = layer(src, pos_embed=pe)
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
n = len(self.in_channels)
inner = [proj[-1]]
for k in range(n - 1, 0, -1):
j = n - 1 - k
top = self.lateral_convs[j](inner[0])
inner[0] = top
up = F.interpolate(top, scale_factor=2., mode='nearest')
inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))
outs = [inner[0]]
for k in range(n - 1):
outs.append(self.pan_blocks[k](
torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
return outs
# ---------------------------------------------------------------------------
# Decoder — DFINETransformer
# ---------------------------------------------------------------------------
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
"""
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
attention_weights : [bs, Lq, n_head, sum(pts)]
"""
_, c = value[0].shape[:2] # bs*n_head, c
_, Lq, n_head, _, _ = sampling_locations.shape
bs = sampling_locations.shape[0]
n_h = n_head
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
sampled = []
for lvl, (h, w) in enumerate(spatial_shapes):
val_l = value[lvl].reshape(bs * n_h, c, h, w)
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
out = out.reshape(bs, n_h * c, Lq)
return out.permute(0, 2, 1) # [bs, Lq, hidden]
class MSDeformableAttention(nn.Module):
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim, self.num_heads = embed_dim, num_heads
self.head_dim = embed_dim // num_heads
pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
self.num_points_list = pts
self.offset_scale = offset_scale
total = num_heads * sum(pts)
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
def forward(self, query, ref_pts, value, spatial_shapes):
bs, Lq = query.shape[:2]
offsets = self.sampling_offsets(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list), 2)
attn_w = F.softmax(
self.attention_weights(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
scale = self.num_points_scale.to(query).unsqueeze(-1)
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
class Gate(nn.Module):
def __init__(self, d_model, device=None, dtype=None, operations=None):
super().__init__()
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x1, x2):
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
return self.norm(g1 * x1 + g2 * x2)
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
super().__init__()
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x)
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.activation = nn.ReLU()
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
q = k = target if query_pos is None else target + query_pos
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
target = self.norm1(target + t2)
t2 = self.cross_attn(
target if query_pos is None else target + query_pos,
ref_pts, value, spatial_shapes)
target = self.gateway(target, t2)
t2 = self.linear2(self.activation(self.linear1(target)))
target = self.norm3((target + t2).clamp(-65504, 65504))
return target
# ---------------------------------------------------------------------------
# FDR utilities
# ---------------------------------------------------------------------------
def weighting_function(reg_max, up, reg_scale):
"""Non-uniform weighting function W(n) for FDR box regression."""
ub1 = (abs(up[0]) * abs(reg_scale)).item()
ub2 = ub1 * 2
step = (ub1 + 1) ** (2 / (reg_max - 2))
left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
right = [ (step ** i) - 1 for i in range(1, reg_max // 2)]
vals = [-ub2] + left + [0] + right + [ub2]
return torch.tensor(vals, dtype=up.dtype, device=up.device)
def distance2bbox(points, distance, reg_scale):
"""Decode edge-distances → cxcywh boxes."""
rs = abs(reg_scale).to(dtype=points.dtype)
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs)
x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
return torch.stack([x0, y0, x1_, y1_], -1)
class Integral(nn.Module):
"""Sum Pr(n)·W(n) over the distribution bins."""
def __init__(self, reg_max=32):
super().__init__()
self.reg_max = reg_max
def forward(self, x, project):
shape = x.shape
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
return x.reshape(list(shape[:-1]) + [-1])
class LQE(nn.Module):
"""Location Quality Estimator — refines class scores using corner distribution."""
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
super().__init__()
self.k, self.reg_max = k, reg_max
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
def forward(self, scores, pred_corners):
B, L, _ = pred_corners.shape
prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1)
topk, _ = prob.topk(self.k, -1)
stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1)
return scores + self.reg_conf(stat.reshape(B, L, -1))
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.nhead = nhead
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
self.layers = nn.ModuleList([
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx + 1)
])
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
def _value_op(self, memory, spatial_shapes):
"""Reshape memory to per-level value tensors for deformable attention."""
c = self.hidden_dim // self.nhead
split = [h * w for h, w in spatial_shapes]
val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim]
# → [bs, n_head, c, sum_hw]
val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw]
return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l]
def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention
# reshape to [bs*n_head, c, h_l, w_l]
value = []
for lvl, (h, w) in enumerate(spatial_shapes):
v = val_split_flat[lvl] # [bs*n_head, c, h*w]
value.append(v.reshape(v.shape[0], v.shape[1], h, w))
ref_pts = F.sigmoid(ref_pts_unact)
output = target
output_detach = pred_corners_undetach = 0
dec_bboxes, dec_logits = [], []
for i, layer in enumerate(self.layers):
ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4]
query_pos = query_pos_head(ref_pts).clamp(-10, 10)
output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)
if i == 0:
ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
ref_unact = torch.log(ref_unact / (1 - ref_unact))
pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
ref_pts_initial = pre_bboxes.detach()
pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)
if i == self.eval_idx:
scores = score_head[i](output)
scores = self.lqe_layers[i](scores, pred_corners)
dec_bboxes.append(inter_ref_bbox)
dec_logits.append(scores)
break
pred_corners_undetach = pred_corners
ref_pts = inter_ref_bbox.detach()
output_detach = output.detach()
return torch.stack(dec_bboxes), torch.stack(dec_logits)
class DFINETransformer(nn.Module):
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
assert len(feat_strides) == len(feat_channels)
self.hidden_dim = hidden_dim
self.num_queries = num_queries
self.num_levels = num_levels
self.eps = eps
self.eval_spatial_size = eval_spatial_size
self.feat_strides = list(feat_strides)
for i in range(num_levels - len(feat_strides)):
self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))
# input projection (expects pre-fused weights)
self.input_proj = nn.ModuleList()
for ch in feat_channels:
if ch == hidden_dim:
self.input_proj.append(nn.Identity())
else:
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
in_ch = feat_channels[-1]
for i in range(num_levels - len(feat_channels)):
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
in_ch = hidden_dim
# FDR parameters (non-trainable placeholders, set from config)
self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
self.enc_output = nn.Sequential(OrderedDict([
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.dec_score_head = nn.ModuleList(
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.dec_bbox_head = nn.ModuleList(
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx_ + 1)])
self.integral = Integral(reg_max)
if eval_spatial_size:
# Register as buffers so checkpoint values override the freshly-computed defaults
anchors, valid_mask = self._gen_anchors()
self.register_buffer('anchors', anchors)
self.register_buffer('valid_mask', valid_mask)
def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
if spatial_shapes is None:
h0, w0 = self.eval_spatial_size
spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
anchors = torch.cat(anchors, 1).to(device)
valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
return anchors, valid_mask
def _encoder_input(self, feats: List[torch.Tensor]):
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
for i in range(len(feats), self.num_levels):
proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
flat, shapes = [], []
for f in proj:
_, _, h, w = f.shape
flat.append(f.flatten(2).permute(0, 2, 1))
shapes.append([h, w])
return torch.cat(flat, 1), shapes
def _decoder_input(self, memory: torch.Tensor):
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
if memory.shape[0] > 1:
anchors = anchors.repeat(memory.shape[0], 1, 1)
mem = valid_mask.to(memory) * memory
out_mem = self.enc_output(mem)
logits = self.enc_score_head(out_mem)
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
idx_e = idx.unsqueeze(-1)
topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
return topk_mem.detach(), topk_ref.detach()
def forward(self, feats: List[torch.Tensor]):
memory, shapes = self._encoder_input(feats)
content, ref = self._decoder_input(memory)
out_bboxes, out_logits = self.decoder(
content, ref, memory, shapes,
self.dec_bbox_head, self.dec_score_head,
self.query_pos_head, self.pre_bbox_head, self.integral)
return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
class RTv4(nn.Module):
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.device = device
self.dtype = dtype
self.operations = operations
self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)
self.num_classes = num_classes
self.num_queries = num_queries
self.load_device = comfy.model_management.get_torch_device()
def _forward(self, x: torch.Tensor):
return self.decoder(self.encoder(self.backbone(x)))
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
logits = outputs['pred_logits']
boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
scores = F.sigmoid(logits)
scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
labels = idx % self.num_classes
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
return self.postprocess(outputs, orig_size)

View File

@@ -99,7 +99,7 @@ class Resample(nn.Module):
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
@@ -109,7 +109,22 @@ class Resample(nn.Module):
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
@@ -130,24 +145,19 @@ class Resample(nn.Module):
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :]
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
deferred_x = feat_cache[idx + 1]
if deferred_x is not None:
x = torch.cat([deferred_x, x], 2)
feat_cache[idx + 1] = None
if x.shape[2] == 1 and not final:
feat_cache[idx + 1] = x
x = None
feat_idx[0] += 2
feat_idx[0] += 1
return x
@@ -167,12 +177,19 @@ class ResidualBlock(nn.Module):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x, feat_cache=None, feat_idx=[0]):
old_x = x
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -196,7 +213,7 @@ class AttentionBlock(nn.Module):
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
@@ -266,10 +283,17 @@ class Encoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -279,16 +303,14 @@ class Encoder3d(nn.Module):
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
if x is None:
return None
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, final=final)
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
@@ -296,7 +318,14 @@ class Encoder3d(nn.Module):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -360,48 +389,18 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1))
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
out_chunks.append(x)
return
layer = self.upsamples[layer_idx]
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
for frame_idx in range(0, x.shape[2], 2):
self.run_up(
layer_idx + 1,
[x[:, :, frame_idx:frame_idx + 2, :, :]],
feat_cache,
feat_idx.copy(),
out_chunks,
)
del x
return
next_x_ref = [x]
del x
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -410,21 +409,42 @@ class Decoder3d(nn.Module):
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
out_chunks = []
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
return out_chunks
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_cache_layers(model):
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
if isinstance(m, CausalConv3d):
count += 1
return count
@@ -462,12 +482,11 @@ class WanVAE(nn.Module):
conv_idx = [0]
## cache
t = x.shape[2]
t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
iter_ = 1 + (t - 1) // 4
feat_map = None
if iter_ > 1:
feat_map = [None] * count_cache_layers(self.encoder)
## 对encode输入的x按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
feat_map = [None] * count_conv3d(self.encoder)
## 对encode输入的x按时间拆分为1、4、4、4....
for i in range(iter_):
conv_idx = [0]
if i == 0:
@@ -477,23 +496,20 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=feat_map,
feat_idx=conv_idx,
final=(i == (iter_ - 1)))
if out_ is None:
continue
feat_idx=conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu
def decode(self, z):
conv_idx = [0]
# z: [b,c,t,h,w]
iter_ = 1 + z.shape[2] // 2
iter_ = z.shape[2]
feat_map = None
if iter_ > 1:
feat_map = [None] * count_cache_layers(self.decoder)
feat_map = [None] * count_conv3d(self.decoder)
x = self.conv2(z)
for i in range(iter_):
conv_idx = [0]
@@ -504,8 +520,8 @@ class WanVAE(nn.Module):
feat_idx=conv_idx)
else:
out_ = self.decoder(
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
x[:, :, i:i + 1, :, :],
feat_cache=feat_map,
feat_idx=conv_idx)
out += out_
return torch.cat(out, 2)
out = torch.cat([out, out_], 2)
return out

View File

@@ -99,9 +99,6 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False

View File

@@ -1,71 +1,9 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int
def read_tensor_file_slice_into(tensor, destination):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
return False
if info.size == 0:
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()
class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype
@@ -141,17 +79,3 @@ def interpret_gathered_like(tensors, gathered):
return dest_views
aimdo_enabled = False
extra_ram_release_callback = None
RAM_CACHE_HEADROOM = 0
def set_ram_cache_release_state(callback, headroom):
global extra_ram_release_callback
global RAM_CACHE_HEADROOM
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)

View File

@@ -21,7 +21,6 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -52,7 +51,6 @@ import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.model_management
import comfy.patcher_extension
@@ -287,12 +285,6 @@ class BaseModel(torch.nn.Module):
return data
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Override in subclasses to handle model-specific cond slicing for context windows.
Return a sliced cond object, or None to fall through to default handling.
Use comfy.context_windows.slice_cond() for common cases."""
return None
def extra_conds(self, **kwargs):
out = {}
concat_cond = self.concat_cond(**kwargs)
@@ -891,7 +883,7 @@ class Flux(BaseModel):
return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs):
return kwargs.get("pooled_output", None)
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -938,10 +930,9 @@ class LongCatImage(Flux):
transformer_options = transformer_options.copy()
rope_opts = transformer_options.get("rope_options", {})
rope_opts = dict(rope_opts)
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
rope_opts.setdefault("shift_t", 1.0)
rope_opts.setdefault("shift_y", pe_len)
rope_opts.setdefault("shift_x", pe_len)
rope_opts.setdefault("shift_y", 512.0)
rope_opts.setdefault("shift_x", 512.0)
transformer_options["rope_options"] = rope_opts
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
@@ -1062,10 +1053,6 @@ class LTXAV(BaseModel):
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
ref_audio = kwargs.get("ref_audio", None)
if ref_audio is not None:
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@@ -1388,11 +1375,6 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "vace_context":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@@ -1445,11 +1427,6 @@ class WAN21_HuMo(WAN21):
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@@ -1467,13 +1444,6 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "face_pixel_values":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
if cond_key == "pose_latents":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@@ -1510,11 +1480,6 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "audio_embed":
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@@ -1958,7 +1923,3 @@ class Kandinsky5Image(Kandinsky5):
def concat_cond(self, **kwargs):
return None
class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)

View File

@@ -1,5 +1,4 @@
import json
import comfy.memory_management
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@@ -698,12 +697,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["audio_model"] = "ace1.5"
return dit_config
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
dit_config = {}
dit_config["image_model"] = "RT_DETR_v4"
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -1125,13 +1118,8 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
new[:old_weight.shape[0]] = old_weight
old_weight = new
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
old_weight = old_weight.clone()
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
if comfy.memory_management.aimdo_enabled:
weight = weight.clone()
old_weight = weight
w = weight
w[:] = fun(weight)

View File

@@ -55,7 +55,6 @@ total_vram = 0
# Training Related State
in_training = False
training_fp8_bwd = False
def get_supported_float8_types():
@@ -271,23 +270,6 @@ try:
except:
OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
def raise_non_oom(e):
if not is_oom(e):
raise e
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
@@ -401,7 +383,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@@ -506,28 +488,6 @@ def module_size(module):
module_mem += t.nbytes
return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
class LoadedModel:
def __init__(self, model):
self._set_model(model)
@@ -542,7 +502,6 @@ class LoadedModel:
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
self._patcher_finalizer.atexit = False
def _switch_parent(self):
model = self._parent_model()
@@ -556,9 +515,6 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self):
return self.model.loaded_size()
@@ -589,7 +545,6 @@ class LoadedModel:
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
self.model_finalizer.atexit = False
return real_model
def should_reload_model(self, force_patch_weights=False):
@@ -661,7 +616,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -669,19 +624,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if device is None or shift_model.device == device:
if shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
for x in sorted(can_unload):
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
ram_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
@@ -690,25 +644,16 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
if ram_to_free > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
if len(unloaded_model) > 0:
soft_empty_cache()
elif device is not None:
else:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
@@ -767,27 +712,17 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
#want to do.
#FIXME: This should subtract off the to_load current pin consumption.
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
for device in total_memory_required:
if device != torch.device("cpu"):
@@ -1004,7 +939,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
@@ -1053,12 +988,6 @@ def intermediate_device():
else:
return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
@@ -1219,7 +1148,6 @@ def reset_cast_buffers():
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
soft_empty_cache()
@@ -1279,11 +1207,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0)
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)
@@ -1326,9 +1249,9 @@ MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
@@ -1339,7 +1262,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except RuntimeError:
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
@@ -1403,6 +1326,8 @@ def unpin_memory(tensor):
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
else:
logging.warning("Unpin error.")
@@ -1719,19 +1644,6 @@ def supports_nvfp4_compute(device=None):
return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
@@ -1754,16 +1666,12 @@ def lora_compute_dtype(device):
return dtype
def synchronize():
if cpu_mode():
return
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
if cpu_mode():
return
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()

View File

@@ -297,8 +297,8 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model)
return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory
@@ -599,27 +599,6 @@ class ModelPatcher:
return models
def model_patches_call_function(self, function_name="cleanup", arguments={}):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], function_name):
getattr(patch_list[i], function_name)(**arguments)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], function_name):
getattr(patch_list[k], function_name)(**arguments)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, function_name):
getattr(wrap_func, function_name)(**arguments)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
@@ -736,8 +715,8 @@ class ModelPatcher:
default = True # default random weights in non leaf modules
break
if default and default_device is not None:
for param_name, param in params.items():
param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
@@ -1063,10 +1042,6 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
def pinned_memory_size(self):
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload):
pass
@@ -1087,7 +1062,6 @@ class ModelPatcher:
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self):
self.model_patches_call_function(function_name="cleanup")
self.clean_hooks()
if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None
@@ -1657,16 +1631,6 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:

View File

@@ -80,21 +80,6 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None
xfer_dest = None
@@ -306,40 +291,10 @@ class CastWeightBiasOp:
bias_function = []
class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(in_features, out_features, bias, device, dtype)
return
@@ -360,21 +315,32 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.in_features, self.out_features),
bias_shape=(self.out_features,),
)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k,v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self):
@@ -566,53 +532,6 @@ class disable_weight_init:
return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return
torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Keep shape/dtype visible for module introspection without reserving storage.
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
self.weight = torch.nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
requires_grad=False,
)
self.bias = None
self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)
def reset_parameters(self):
self.bias = None
return None
@@ -741,29 +660,23 @@ class fp8_ops(manual_cast):
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear, cublas_half_matmul
from cublas_ops import CublasLinear
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(manual_cast):
class Linear(CublasLinear, manual_cast.Linear):
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
return super().forward(input)
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
return super().forward(*args, **kwargs)
# ==============================================================================
# Mixed Precision Operations
@@ -776,104 +689,6 @@ from .quant_ops import (
)
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
When training_fp8_bwd is enabled:
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
- Cached input is FP8 (half the memory of bf16)
When training_fp8_bwd is disabled:
- Forward: quantize input per layout, use quantized matmul
- Backward: dequantize weight to compute_dtype, use standard matmul
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input for forward (same layout as weight)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
output = torch.nn.functional.linear(q_input, w, b)
# Unflatten output to match original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
# Save for backward
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
if ctx.fp8_bwd:
# Cache FP8 quantized input — half the memory of bf16
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
ctx.q_input = q_input # already FP8, reuse
else:
# NVFP4 or other layout — quantize input to FP8 for backward
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
ctx.save_for_backward(weight)
else:
ctx.q_input = None
ctx.save_for_backward(input_float, weight)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Value casting — only difference between fp8 and non-fp8 paths
if ctx.fp8_bwd:
weight, = ctx.saved_tensors
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
weight_mm = weight
elif isinstance(weight, QuantizedTensor):
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
else:
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
input_mm = ctx.q_input
else:
input_float, weight = ctx.saved_tensors
# Standard tensors → torch.mm does regular matmul
grad_mm = grad_2d
if isinstance(weight, QuantizedTensor):
weight_mm = weight.dequantize().to(compute_dtype)
else:
weight_mm = weight.to(compute_dtype)
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
# Computation — same for both paths, dispatch handles the rest
grad_input = torch.mm(grad_mm, weight_mm)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
grad_weight = None
if ctx.weight_requires_grad:
grad_weight = torch.mm(grad_mm.t(), input_mm)
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
@@ -928,7 +743,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
@@ -966,22 +780,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@@ -1035,9 +833,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
@@ -1072,37 +867,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
_use_quantized = (
getattr(self, 'layout_type', None) is not None and
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@@ -1150,10 +918,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
@@ -1164,15 +929,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")

View File

@@ -1,8 +1,6 @@
import torch
import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import psutil
from comfy.cli_args import args
@@ -13,37 +11,19 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):
module._pin = pin
else:
module.pin_failed = True
return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
except RuntimeError:
module.pin_failed = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
comfy.model_management.unpin_memory(module._pin)
del module._pin
del module._pin_hostbuf
return size

View File

@@ -1,33 +1,38 @@
import torch
import logging
from comfy.cli_args import args
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
)
_CK_AVAILABLE = True
if torch.version.cuda is None:
ck.registry.disable("cuda")
else:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
if args.cpu:
_CK_AVAILABLE = False
else:
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
)
_CK_AVAILABLE = True
if torch.version.cuda is None:
ck.registry.disable("cuda")
else:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
_CK_AVAILABLE = False
if not _CK_AVAILABLE:
class QuantizedTensor:
pass
@@ -43,18 +48,6 @@ except ImportError as e:
def get_layout_class(name):
return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float
# ==============================================================================
@@ -96,31 +89,6 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@@ -174,8 +142,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@@ -196,14 +162,6 @@ QUANT_ALGOS = {
},
}
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ==============================================================================
# Re-exports for backward compatibility

View File

@@ -8,12 +8,12 @@ import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
samples = samples.to(comfy.model_management.intermediate_device())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
samples = samples.to(comfy.model_management.intermediate_device())
return samples

View File

@@ -985,8 +985,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
noise = noise.to(device)
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
@@ -1028,7 +1028,6 @@ class CFGGuider:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else:
denoise_mask = denoise_masks[0]
denoise_mask = denoise_mask.float()
self.conds = {}
for k in self.original_conds:

View File

@@ -61,7 +61,6 @@ import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.model_patcher
import comfy.lora
@@ -280,6 +279,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -423,13 +425,13 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -453,7 +455,7 @@ class VAE:
self.output_channels = 3
self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
@@ -837,6 +839,9 @@ class VAE:
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@@ -866,16 +871,13 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@@ -885,16 +887,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -903,7 +905,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -912,7 +914,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@@ -921,7 +923,7 @@ class VAE:
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
@@ -930,7 +932,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@@ -946,25 +948,13 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
@@ -973,7 +963,6 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
@@ -1034,19 +1023,13 @@ class VAE:
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
@@ -1055,7 +1038,6 @@ class VAE:
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
@@ -1223,11 +1205,6 @@ class TEModel(Enum):
QWEN3_8B = 20
QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
QWEN35_08B = 23
QWEN35_2B = 24
QWEN35_4B = 25
QWEN35_9B = 26
QWEN35_27B = 27
def detect_te_model(sd):
@@ -1267,17 +1244,6 @@ def detect_te_model(sd):
return TEModel.QWEN25_3B
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
weight = sd['model.language_model.layers.0.input_layernorm.weight']
if weight.shape[0] == 1024:
return TEModel.QWEN35_08B
if weight.shape[0] == 2560:
return TEModel.QWEN35_4B
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1310,12 +1276,11 @@ def t5xxl_detect(clip_data):
return {}
def llama_detect(clip_data):
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
weight_name = "model.layers.0.self_attn.k_proj.weight"
for sd in clip_data:
for weight_name in weight_names:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
@@ -1443,11 +1408,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
@@ -1736,16 +1696,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
"""
dtype = model_options.get("dtype", None)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
#Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
sd = temp_sd
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)

View File

@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2]
if pooled is not None:
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
else:
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
if len(o) > 2:
extra = {}
for k in o[2]:
v = o[2][k]
if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
extra[k] = v
r = r + (extra,)
@@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def parse_parentheses(string):
result = []
@@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

View File

@@ -1734,21 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
class RT_DETR_v4(supported_models_base.BASE):
unet_config = {
"image_model": "RT_DETR_v4",
}
supported_inference_dtypes = [torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.RT_DETR_v4(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@@ -224,7 +224,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = True
lm_head: bool = False
stop_tokens = [151643, 151645]
@dataclass
@@ -655,17 +655,6 @@ class Llama2_(nn.Module):
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
return past_key_values[0][2]
def compute_freqs_cis(self, position_ids, device):
return precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
@@ -678,12 +667,17 @@ class Llama2_(nn.Module):
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = self.get_past_len(past_key_values)
past_len = past_key_values[0][2]
if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
mask = None
if attention_mask is not None:
@@ -818,16 +812,9 @@ class BaseGenerate:
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config
if stop_tokens is None:
stop_tokens = self.model.config.stop_tokens
@@ -842,8 +829,11 @@ class BaseGenerate:
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
@@ -854,7 +844,7 @@ class BaseGenerate:
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
@@ -866,7 +856,7 @@ class BaseGenerate:
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
@@ -877,11 +867,6 @@ class BaseGenerate:
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if presence_penalty is not None and presence_penalty != 0.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] -= presence_penalty
if temperature != 1.0:
logits = logits / temperature
@@ -912,9 +897,6 @@ class BaseGenerate:
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
if self.model.config.lm_head:
return self.model.lm_head(input)
module = self.model.embed_tokens
offload_stream = None
@@ -1046,19 +1028,12 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
if attention_mask is not None:
# Assign compact sequential positions to attended tokens only,
# skipping over padding so post-padding tokens aren't inflated.
after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
else:
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]

View File

@@ -64,13 +64,7 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
return [output]
IMAGE_PAD_TOKEN_ID = 151655
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
@@ -78,8 +72,10 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
name="qwen25_7b",
tokenizer=LongCatImageBaseTokenizer,
)
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
skip_template = False
if text.startswith("<|im_start|>"):
skip_template = True
@@ -94,14 +90,11 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
else:
has_images = images is not None and len(images) > 0
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
prefix_ids = base_tok.tokenizer(
template_prefix, add_special_tokens=False
self.longcat_template_prefix, add_special_tokens=False
)["input_ids"]
suffix_ids = base_tok.tokenizer(
self.SUFFIX, add_special_tokens=False
self.longcat_template_suffix, add_special_tokens=False
)["input_ids"]
prompt_tokens = base_tok.tokenize_with_weights(
@@ -113,14 +106,6 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
suffix_pairs = [(t, 1.0) for t in suffix_ids]
combined = prefix_pairs + prompt_pairs + suffix_pairs
if has_images:
embed_count = 0
for i in range(len(combined)):
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
embed_count += 1
tokens = {"qwen25_7b": [combined]}
return tokens

View File

@@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
@@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:

View File

@@ -1,833 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
import os
import math
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy import sd1_clip
import comfy.text_encoders.qwen_vl
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
def _qwen35_layer_types(n):
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
@dataclass
class Qwen35Config:
vocab_size: int = 248320
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 24
# Full attention params
num_attention_heads: int = 8
num_key_value_heads: int = 2
head_dim: int = 256
partial_rotary_factor: float = 0.25
# Linear attention (DeltaNet) params
linear_num_key_heads: int = 16
linear_num_value_heads: int = 16
linear_key_head_dim: int = 128
linear_value_head_dim: int = 128
conv_kernel_size: int = 4
# Shared params
max_position_embeddings: int = 32768
rms_norm_eps: float = 1e-6
rope_theta: float = 10000000.0
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
rms_norm_add: bool = True
mlp_activation: str = "silu"
qkv_bias: bool = False
final_norm: bool = True
lm_head: bool = False
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
transformer_type: str = "qwen35_2b"
rope_dims: list = None
rope_scale: float = None
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
QWEN35_MODELS = {
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
}
def _make_config(model_type, config_dict={}):
overrides = QWEN35_MODELS.get(model_type, {}).copy()
overrides.pop("vision", None)
if "num_hidden_layers" in overrides:
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
overrides.update(config_dict)
return Qwen35Config(**overrides)
class RMSNormGated(RMSNorm):
def forward(self, x, gate):
return super().forward(x) * F.silu(gate.to(x.dtype))
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
initial_dtype = query.dtype
query = F.normalize(query, dim=-1)
key = F.normalize(key, dim=-1)
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
query = F.pad(query, (0, 0, 0, pad_size))
key = F.pad(key, (0, 0, 0, pad_size))
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
total_sequence_length = sequence_length + pad_size
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
v_beta = value * beta.unsqueeze(-1)
k_beta = key * beta.unsqueeze(-1)
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
g = g.cumsum(dim=-1)
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
for i in range(1, chunk_size):
row = attn[..., i, :i].clone()
sub = attn[..., :i, :i].clone()
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
value = attn @ v_beta
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
for i in range(0, total_sequence_length // chunk_size):
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
v_new = v_i - v_prime
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
core_attn_out[:, :, i] = attn_inter + attn @ v_new
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp()
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
)
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
core_attn_out = core_attn_out[:, :, :sequence_length]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
# weight: [channels, kernel_size]
state_len = conv_state.shape[-1]
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
conv_state.copy_(combined[:, :, -state_len:])
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
if bias is not None:
out = out + bias.unsqueeze(0).unsqueeze(-1)
return F.silu(out).to(x.dtype)
# GatedDeltaNet - Linear Attention Layer
class GatedDeltaNet(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
hidden = config.hidden_size
self.num_key_heads = config.linear_num_key_heads
self.num_value_heads = config.linear_num_value_heads
self.key_head_dim = config.linear_key_head_dim
self.value_head_dim = config.linear_value_head_dim
self.conv_kernel_size = config.conv_kernel_size
key_dim = self.num_key_heads * self.key_head_dim
value_dim = self.num_value_heads * self.value_head_dim
self.key_dim = key_dim
self.value_dim = value_dim
conv_dim = key_dim * 2 + value_dim
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(self, x, past_key_value=None, **kwargs):
batch_size, seq_len, _ = x.shape
use_recurrent = (
past_key_value is not None
and past_key_value[2] > 0
and seq_len == 1
)
# Projections (shared)
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
z = self.in_proj_z(x)
b = self.in_proj_b(x)
a = self.in_proj_a(x)
# Conv1d
if use_recurrent:
recurrent_state, conv_state, step_index = past_key_value
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
else:
if past_key_value is not None:
recurrent_state, conv_state, step_index = past_key_value
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
# Split QKV and compute beta/g
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
# Delta rule
if use_recurrent:
# single-token path: work in [B, heads, dim] without seq dim
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=1)
key = key.repeat_interleave(rep, dim=1)
scale = self.key_head_dim ** -0.5
q = F.normalize(query.float(), dim=-1) * scale
k = F.normalize(key.float(), dim=-1)
v = value.float()
beta_t = beta.reshape(batch_size, -1)
g_t = g.reshape(batch_size, -1).exp()
# In-place state update: [B, heads, k_dim, v_dim]
recurrent_state.mul_(g_t[:, :, None, None])
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
delta = (v - kv_mem) * beta_t[:, :, None]
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
present_key_value = (recurrent_state, conv_state, step_index + 1)
else:
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
if self.num_value_heads != self.num_key_heads:
rep = self.num_value_heads // self.num_key_heads
query = query.repeat_interleave(rep, dim=2)
key = key.repeat_interleave(rep, dim=2)
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
query, key, value, g=g, beta=beta,
initial_state=None,
output_final_state=past_key_value is not None,
)
present_key_value = None
if past_key_value is not None:
if last_recurrent_state is not None:
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
# Gated norm + output projection (shared)
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
return output, present_key_value
# GatedAttention - Full Attention with output gating
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
"""Compute RoPE frequencies for partial rotary embeddings."""
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if mrope_section is not None and position_ids.shape[0] == 3:
mrope_section_2 = [s * 2 for s in mrope_section]
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
"""Apply RoPE to only the first rotary_dim dimensions."""
xq_rot = xq[..., :rotary_dim]
xq_pass = xq[..., rotary_dim:]
xk_rot = xk[..., :rotary_dim]
xk_pass = xk[..., rotary_dim:]
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
xq = torch.cat([xq_rot, xq_pass], dim=-1)
xk = torch.cat([xk_rot, xk_pass], dim=-1)
return xq, xk
class GatedAttention(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.inner_size = self.num_heads * self.head_dim
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
# q_proj outputs 2x: query + gate
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
# QK norms with (1+weight) scaling
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
batch_size, seq_length, _ = x.shape
# Project Q (with gate), K, V
qg = self.q_proj(x)
# Split into query and gate: each is [B, seq, inner_size]
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
xk = self.k_proj(x)
xv = self.v_proj(x)
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
# Apply partial RoPE
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
# KV cache
present_key_value = None
if past_key_value is not None:
past_key, past_value, index = past_key_value
num_tokens = xk.shape[2]
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + num_tokens] = xk
past_value[:, :, index:index + num_tokens] = xv
xk = past_key[:, :, :index + num_tokens]
xv = past_value[:, :, :index + num_tokens]
present_key_value = (past_key, past_value, index + num_tokens)
else:
if index > 0:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
# Expand KV heads for GQA
if self.num_heads != self.num_kv_heads:
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
output = output * gate.sigmoid()
return self.o_proj(output), present_key_value
# Hybrid Transformer Block
class Qwen35TransformerBlock(nn.Module):
def __init__(self, config, index, device=None, dtype=None, ops=None):
super().__init__()
self.layer_type = config.layer_types[index]
if self.layer_type == "linear_attention":
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
else:
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
if self.layer_type == "linear_attention":
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
else:
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
x = x + h
x = x + self.mlp(self.post_attention_layernorm(x))
return x, present_key_value
# Qwen35 Transformer Backbone
class Qwen35Transformer(Llama2_):
def __init__(self, config, device=None, dtype=None, ops=None):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def get_past_len(self, past_key_values):
for i, layer in enumerate(self.layers):
if layer.layer_type == "full_attention":
if len(past_key_values) > i:
return past_key_values[i][2]
break
return 0
def compute_freqs_cis(self, position_ids, device):
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
return precompute_partial_rope(
self.config.head_dim, rotary_dim, position_ids,
self.config.rope_theta, device=device,
mrope_section=self.config.mrope_section,
)
# Vision Encoder
class Qwen35VisionPatchEmbed(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = config["patch_size"]
self.temporal_patch_size = config["temporal_patch_size"]
self.in_channels = config["in_channels"]
self.embed_dim = config["hidden_size"]
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(self, hidden_state):
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
class Qwen35VisionRotaryEmbedding(nn.Module):
def __init__(self, dim, theta=10000.0):
super().__init__()
self.dim = dim
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen):
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Qwen35VisionAttention(nn.Module):
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
super().__init__()
self.dim = hidden_size
self.num_heads = num_heads
self.head_dim = self.dim // self.num_heads
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
seq_length = x.shape[0]
query_states, key_states, value_states = (
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
# Process per-sequence attention
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
q_splits = torch.split(query_states, lengths, dim=0)
k_splits = torch.split(key_states, lengths, dim=0)
v_splits = torch.split(value_states, lengths, dim=0)
attn_outputs = []
for q, k, v in zip(q_splits, k_splits, v_splits):
q = q.transpose(0, 1).unsqueeze(0)
k = k.transpose(0, 1).unsqueeze(0)
v = v.transpose(0, 1).unsqueeze(0)
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1)
return self.proj(attn_output)
class Qwen35VisionBlock(nn.Module):
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
return x + self.mlp(self.norm2(x))
class Qwen35VisionPatchMerger(nn.Module):
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
super().__init__()
merge_dim = hidden_size * (spatial_merge_size ** 2)
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
self.merge_dim = merge_dim
def forward(self, x):
x = self.norm(x).view(-1, self.merge_dim)
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
class Qwen35VisionModel(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.spatial_merge_size = config["spatial_merge_size"]
self.patch_size = config["patch_size"]
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.hidden_size = config["hidden_size"]
self.num_heads = config["num_heads"]
self.num_position_embeddings = config["num_position_embeddings"]
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
self.blocks = nn.ModuleList([
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
for _ in range(config["depth"])
])
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
def rot_pos_emb(self, grid_thw):
merge_size = self.spatial_merge_size
grid_thw_list = grid_thw.tolist()
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
freq_table = self.rotary_pos_emb(max_hw)
device = freq_table.device
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
offset = 0
for num_frames, height, width in grid_thw_list:
num_frames, height, width = int(num_frames), int(height), int(width)
merged_h, merged_w = height // merge_size, width // merge_size
block_rows = torch.arange(merged_h, device=device)
block_cols = torch.arange(merged_w, device=device)
intra_row = torch.arange(merge_size, device=device)
intra_col = torch.arange(merge_size, device=device)
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
coords = torch.stack((row_idx, col_idx), dim=-1)
if num_frames > 1:
coords = coords.repeat(num_frames, 1)
num_tokens = coords.shape[0]
pos_ids[offset:offset + num_tokens] = coords
offset += num_tokens
embeddings = freq_table[pos_ids]
embeddings = embeddings.flatten(1)
return embeddings
def fast_pos_embed_interpolate(self, grid_thw):
grid_thw_list = grid_thw.tolist()
grid_ts = [int(row[0]) for row in grid_thw_list]
grid_hs = [int(row[1]) for row in grid_thw_list]
grid_ws = [int(row[2]) for row in grid_thw_list]
device = self.pos_embed.weight.device
idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]
for t, h, w in grid_thw_list:
h, w = int(h), int(w)
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor
base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]
weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]
for j in range(4):
idx_list[j].extend(indices[j].tolist())
weight_list[j].extend(weights[j].tolist())
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
patch_pos_embeds_permute = []
merge_size = self.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
return torch.cat(patch_pos_embeds_permute)
def forward(self, x, grid_thw):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos().unsqueeze(-2)
sin = emb.sin().unsqueeze(-2)
sin_half = sin.shape[-1] // 2
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
merged = self.merger(x)
return merged
# Model Wrapper
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
model_type = "qwen35_2b"
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = _make_config(self.model_type, config_dict)
self.num_layers = config.num_hidden_layers
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
model_config = self.model.config
past_key_values = []
for i in range(model_config.num_hidden_layers):
if model_config.layer_types[i] == "linear_attention":
recurrent_state = torch.zeros(
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
device=device, dtype=torch.float32
)
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
conv_state = torch.zeros(
[batch, conv_dim, model_config.conv_kernel_size - 1],
device=device, dtype=execution_dtype
)
past_key_values.append((recurrent_state, conv_state, 0))
else:
past_key_values.append((
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
0
))
return past_key_values
# Tokenizer and Text Encoder Wrappers
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
from transformers import Qwen2Tokenizer
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image]
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
if not thinking:
llama_text += "<think>\n</think>\n"
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 248056: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
class Qwen35ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
class Qwen35_(Qwen35):
pass
Qwen35_.model_type = model_type
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Qwen35TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
def tokenizer(model_type="qwen35_2b"):
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
return Qwen35ImageTokenizer_
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
class Qwen35TEModel_(Qwen35TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
return Qwen35TEModel_

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -425,7 +425,4 @@ class Qwen2VLVisionTransformer(nn.Module):
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states)
# Potentially important for spatially precise edits. This is present in the HF implementation.
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states

View File

@@ -20,8 +20,6 @@
import torch
import math
import struct
import ctypes
import os
import comfy.memory_management
import safetensors.torch
import numpy as np
@@ -34,7 +32,7 @@ from einops import rearrange
from comfy.cli_args import args
import json
import time
import threading
import mmap
import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
@@ -83,17 +81,14 @@ _TYPES = {
}
def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb")
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
f = open(ckpt, "rb", buffering=0)
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
header_size = struct.unpack("<Q", mapping[:8])[0]
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):]
mv = mv[8 + header_size:]
sd = {}
for name, info in header.items():
@@ -107,14 +102,7 @@ def load_safetensors(ckpt):
with warnings.catch_warnings():
#We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable")
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
return sd, header.get("__metadata__", {}),
@@ -897,10 +885,6 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
@@ -1135,8 +1119,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1)
continue
out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@@ -1151,7 +1135,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
mask = torch.ones_like(ps)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
@@ -1174,7 +1158,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None:
pbar.update(1)
out.div_(out_div)
output[b:b+1] = out/out_div
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):

View File

@@ -15,7 +15,6 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
"assets": args.enable_assets,
}

View File

@@ -5,10 +5,6 @@ from comfy_api.latest._input import (
MaskInput,
LatentInput,
VideoInput,
CurvePoint,
CurveInput,
MonotoneCubicCurve,
LinearCurve,
)
__all__ = [
@@ -17,8 +13,4 @@ __all__ = [
"MaskInput",
"LatentInput",
"VideoInput",
"CurvePoint",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs
across ComfyUI instances.
Example::
from comfy_api.latest import Caching
class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
... # check external storage
async def on_store(self, context, value):
... # store to external storage
Caching.register_provider(MyCacheProvider())
"""
from ._caching import CacheProvider, CacheContext, CacheValue
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Register an external cache provider. Providers are called in registration order."""
from comfy_execution.cache_provider import register_cache_provider
register_cache_provider(provider)
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Unregister a previously registered cache provider."""
from comfy_execution.cache_provider import unregister_cache_provider
unregister_cache_provider(provider)
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -147,9 +116,6 @@ class Types:
VOXEL = VOXEL
File3D = File3D
Caching = ComfyAPI_latest.Caching
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -169,7 +135,6 @@ __all__ = [
"Input",
"InputImpl",
"Types",
"Caching",
"ComfyExtension",
"io",
"IO",

View File

@@ -1,42 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@@ -1,5 +1,4 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .video_types import VideoInput
__all__ = [
@@ -8,8 +7,4 @@ __all__ = [
"VideoInput",
"MaskInput",
"LatentInput",
"CurvePoint",
"CurveInput",
"MonotoneCubicCurve",
"LinearCurve",
]

View File

@@ -1,219 +0,0 @@
from __future__ import annotations
import logging
import math
from abc import ABC, abstractmethod
import numpy as np
logger = logging.getLogger(__name__)
CurvePoint = tuple[float, float]
class CurveInput(ABC):
"""Abstract base class for curve inputs.
Subclasses represent different curve representations (control-point
interpolation, analytical functions, LUT-based, etc.) while exposing a
uniform evaluation interface to downstream nodes.
"""
@property
@abstractmethod
def points(self) -> list[CurvePoint]:
"""The control points that define this curve."""
@abstractmethod
def interp(self, x: float) -> float:
"""Evaluate the curve at a single *x* value in [0, 1]."""
def interp_array(self, xs: np.ndarray) -> np.ndarray:
"""Vectorised evaluation over a numpy array of x values.
Subclasses should override this for better performance. The default
falls back to scalar ``interp`` calls.
"""
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
return self.interp_array(np.linspace(0.0, 1.0, size))
@staticmethod
def from_raw(data) -> CurveInput:
"""Convert raw curve data (dict or point list) to a CurveInput instance.
Accepts:
- A ``CurveInput`` instance (returned as-is).
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
"""
if isinstance(data, CurveInput):
return data
if isinstance(data, dict):
raw_points = data["points"]
interpolation = data.get("interpolation", "monotone_cubic")
else:
raw_points = data
interpolation = "monotone_cubic"
points = [(float(x), float(y)) for x, y in raw_points]
if interpolation == "linear":
return LinearCurve(points)
if interpolation != "monotone_cubic":
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
return MonotoneCubicCurve(points)
class MonotoneCubicCurve(CurveInput):
"""Monotone cubic Hermite interpolation over control points.
Mirrors the frontend ``createMonotoneInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
backend evaluation matches the editor preview exactly.
All heavy work (sorting, slope computation) happens once at construction.
``interp_array`` is fully vectorised with numpy.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
self._slopes = self._compute_slopes()
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def _compute_slopes(self) -> np.ndarray:
xs, ys = self._xs, self._ys
n = len(xs)
if n < 2:
return np.zeros(n, dtype=np.float64)
dx = np.diff(xs)
dy = np.diff(ys)
dx_safe = np.where(dx == 0, 1.0, dx)
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
slopes = np.empty(n, dtype=np.float64)
slopes[0] = deltas[0]
slopes[-1] = deltas[-1]
for i in range(1, n - 1):
if deltas[i - 1] * deltas[i] <= 0:
slopes[i] = 0.0
else:
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
for i in range(n - 1):
if deltas[i] == 0:
slopes[i] = 0.0
slopes[i + 1] = 0.0
else:
alpha = slopes[i] / deltas[i]
beta = slopes[i + 1] / deltas[i]
s = alpha * alpha + beta * beta
if s > 9:
t = 3 / math.sqrt(s)
slopes[i] = t * alpha * deltas[i]
slopes[i + 1] = t * beta * deltas[i]
return slopes
def interp(self, x: float) -> float:
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
if x <= xs[0]:
return float(ys[0])
if x >= xs[-1]:
return float(ys[-1])
hi = int(np.searchsorted(xs, x, side='right'))
hi = min(hi, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
if dx == 0:
return float(ys[lo])
t = (x - xs[lo]) / dx
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
"""Fully vectorised evaluation using numpy."""
xs, ys, slopes = self._xs, self._ys, self._slopes
n = len(xs)
if n == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if n == 1:
return np.full_like(xs_in, ys[0], dtype=np.float64)
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
lo = hi - 1
dx = xs[hi] - xs[lo]
dx_safe = np.where(dx == 0, 1.0, dx)
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
t2 = t * t
t3 = t2 * t
h00 = 2 * t3 - 3 * t2 + 1
h10 = t3 - 2 * t2 + t
h01 = -2 * t3 + 3 * t2
h11 = t3 - t2
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
result = np.where(xs_in <= xs[0], ys[0], result)
result = np.where(xs_in >= xs[-1], ys[-1], result)
return result
def __repr__(self) -> str:
return f"MonotoneCubicCurve(points={self._points})"
class LinearCurve(CurveInput):
"""Piecewise linear interpolation over control points.
Mirrors the frontend ``createLinearInterpolator`` in
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
"""
def __init__(self, control_points: list[CurvePoint]):
sorted_pts = sorted(control_points, key=lambda p: p[0])
self._points = [(float(x), float(y)) for x, y in sorted_pts]
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
@property
def points(self) -> list[CurvePoint]:
return list(self._points)
def interp(self, x: float) -> float:
xs, ys = self._xs, self._ys
n = len(xs)
if n == 0:
return 0.0
if n == 1:
return float(ys[0])
return float(np.interp(x, xs, ys))
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
if len(self._xs) == 0:
return np.zeros_like(xs_in, dtype=np.float64)
if len(self._xs) == 1:
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
return np.interp(xs_in, self._xs, self._ys)
def __repr__(self) -> str:
return f"LinearCurve(points={self._points})"

View File

@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
to_skip = int(offset_seconds * audio_stream.sample_rate)
if to_skip < frame.samples:
has_first_frame = True
break
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if self.__duration and frame.time > start_time + self.__duration:
if frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
from comfy.samplers import CFGGuider, Sampler
from comfy.sd import CLIP, VAE
from comfy.sd import StyleModel as StyleModel_
from comfy_api.input import VideoInput, CurveInput as CurveInput_
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -1242,9 +1242,8 @@ class BoundingBox(ComfyTypeIO):
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
from comfy_api.input import CurvePoint
if TYPE_CHECKING:
Type = CurveInput_
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
@@ -1253,18 +1252,6 @@ class Curve(ComfyTypeIO):
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
def as_dict(self):
d = super().as_dict()
if self.default is not None:
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
return d
@comfytype(io_type="HISTOGRAM")
class Histogram(ComfyTypeIO):
"""A histogram represented as a list of bin counts."""
Type = list[int]
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
@@ -1373,7 +1360,6 @@ class NodeInfoV1:
price_badge: dict | None = None
search_aliases: list[str]=None
essentials_category: str=None
has_intermediate_output: bool=None
@dataclass
@@ -1497,16 +1483,6 @@ class Schema:
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
has_intermediate_output: bool=False
"""Flags this node as having intermediate output that should persist across page refreshes.
Nodes with this flag behave like output nodes (their UI results are cached and resent
to the frontend) but do NOT automatically get added to the execution list. This means
they will only execute if they are on the dependency path of a real output node.
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
def validate(self):
'''Validate the schema:
@@ -1606,7 +1582,6 @@ class Schema:
category=self.category,
description=self.description,
output_node=self.is_output_node,
has_intermediate_output=self.has_intermediate_output,
deprecated=self.is_deprecated,
experimental=self.is_experimental,
dev_only=self.is_dev_only,
@@ -1898,14 +1873,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._OUTPUT_NODE
_HAS_INTERMEDIATE_OUTPUT = None
@final
@classproperty
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls.GET_SCHEMA()
return cls._HAS_INTERMEDIATE_OUTPUT
_INPUT_IS_LIST = None
@final
@classproperty
@@ -1998,8 +1965,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list
if cls._NOT_IDEMPOTENT is None:
@@ -2275,6 +2240,5 @@ __all__ = [
"PriceBadge",
"BoundingBox",
"Curve",
"Histogram",
"NodeReplace",
]

View File

@@ -67,7 +67,6 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel):

View File

@@ -7,8 +7,7 @@ class ImageGenerationRequest(BaseModel):
aspect_ratio: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_format: str = Field("url")
resolution: str = Field(...)
response_for: str = Field("url")
class InputUrlObject(BaseModel):
@@ -17,33 +16,24 @@ class InputUrlObject(BaseModel):
class ImageEditRequest(BaseModel):
model: str = Field(...)
images: list[InputUrlObject] = Field(...)
image: InputUrlObject = Field(...)
prompt: str = Field(...)
resolution: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_format: str = Field("url")
aspect_ratio: str | None = Field(...)
response_for: str = Field("url")
class VideoGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
image: InputUrlObject | None = Field(None)
reference_images: list[InputUrlObject] | None = Field(None)
image: InputUrlObject | None = Field(...)
duration: int = Field(...)
aspect_ratio: str | None = Field(...)
resolution: str = Field(...)
seed: int = Field(...)
class VideoExtensionRequest(BaseModel):
prompt: str = Field(...)
video: InputUrlObject = Field(...)
duration: int = Field(default=6)
model: str | None = Field(default=None)
class VideoEditRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
@@ -57,13 +47,8 @@ class ImageResponseObject(BaseModel):
revised_prompt: str | None = Field(None)
class UsageObject(BaseModel):
cost_in_usd_ticks: int | None = Field(None)
class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...)
usage: UsageObject | None = Field(None)
class VideoGenerationResponse(BaseModel):
@@ -80,4 +65,3 @@ class VideoStatusResponse(BaseModel):
status: str | None = Field(None)
video: VideoResponseObject | None = Field(None)
model: str | None = Field(None)
usage: UsageObject | None = Field(None)

View File

@@ -66,17 +66,13 @@ class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
class TaskFile3DInput(BaseModel):
class To3DUVFileInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: TaskFile3DInput = Field(...)
class To3DPartTaskRequest(BaseModel):
File: TaskFile3DInput = Field(...)
File: To3DUVFileInput = Field(...)
class TextureEditImageInfo(BaseModel):
@@ -84,13 +80,7 @@ class TextureEditImageInfo(BaseModel):
class TextureEditTaskRequest(BaseModel):
File3D: TaskFile3DInput = Field(...)
File3D: To3DUVFileInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)
class SmartTopologyRequest(BaseModel):
File3D: TaskFile3DInput = Field(...)
PolygonType: str | None = Field(...)
FaceLevel: str | None = Field(...)

View File

@@ -148,4 +148,3 @@ class MotionControlRequest(BaseModel):
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)

View File

@@ -1,43 +0,0 @@
from pydantic import BaseModel, Field
class QuiverImageObject(BaseModel):
url: str = Field(...)
class QuiverTextToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
prompt: str = Field(...)
instructions: str | None = Field(default=None)
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverImageToSVGRequest(BaseModel):
model: str = Field(default="arrow-preview")
image: QuiverImageObject = Field(...)
auto_crop: bool | None = Field(default=None)
target_size: int | None = Field(default=None, ge=128, le=4096)
temperature: float | None = Field(default=None, ge=0, le=2)
top_p: float | None = Field(default=None, ge=0, le=1)
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
class QuiverSVGResponseItem(BaseModel):
svg: str = Field(...)
mime_type: str | None = Field(default="image/svg+xml")
class QuiverSVGUsage(BaseModel):
total_tokens: int | None = Field(default=None)
input_tokens: int | None = Field(default=None)
output_tokens: int | None = Field(default=None)
class QuiverSVGResponse(BaseModel):
id: str | None = Field(default=None)
created: int | None = Field(default=None)
data: list[QuiverSVGResponseItem] = Field(...)
usage: QuiverSVGUsage | None = Field(default=None)

View File

@@ -1,68 +0,0 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@@ -47,10 +47,6 @@ SEEDREAM_MODELS = {
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
logger = logging.getLogger(__name__)
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
@@ -139,7 +135,6 @@ class ByteDanceImageNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
is_deprecated=True,
)
@classmethod
@@ -947,7 +942,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
]
return await process_video_task(
cls,
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
payload=Image2VideoTaskCreationRequest(model=model, content=x),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
)
@@ -957,12 +952,6 @@ async def process_video_task(
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),

Some files were not shown because too many files have changed in this diff Show More