mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-26 07:27:34 +00:00
Compare commits
56 Commits
toolkit/wi
...
feature/cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a145651cc0 | ||
|
|
b53b10ea61 | ||
|
|
7d5534d8e5 | ||
|
|
5ebb0c2e0b | ||
|
|
a0a64c679f | ||
|
|
8e73678dae | ||
|
|
c2862b24af | ||
|
|
f9ec85f739 | ||
|
|
2d5fd3f5dd | ||
|
|
2d4970ff67 | ||
|
|
e87858e974 | ||
|
|
da6edb5a4e | ||
|
|
6265a239f3 | ||
|
|
d49420b3c7 | ||
|
|
ebf6b52e32 | ||
|
|
25b6d1d629 | ||
|
|
11c15d8832 | ||
|
|
b5d32e6ad2 | ||
|
|
a11f68dd3b | ||
|
|
dc719cde9c | ||
|
|
87cda1fc25 | ||
|
|
45d5c83a30 | ||
|
|
c646d211be | ||
|
|
589228e671 | ||
|
|
e4455fd43a | ||
|
|
f49856af57 | ||
|
|
82b868a45a | ||
|
|
8458ae2686 | ||
|
|
fd0261d2bc | ||
|
|
ab14541ef7 | ||
|
|
6589562ae3 | ||
|
|
fabed694a2 | ||
|
|
f6b869d7d3 | ||
|
|
56ff88f951 | ||
|
|
9fff091f35 | ||
|
|
dcd659590f | ||
|
|
b67ed2a45f | ||
|
|
06957022d4 | ||
|
|
b941913f1d | ||
|
|
cad24ce262 | ||
|
|
68d542cc06 | ||
|
|
735a0465e5 | ||
|
|
8b9d039f26 | ||
|
|
035414ede4 | ||
|
|
1a157e1f97 | ||
|
|
ed7c2c6579 | ||
|
|
379fbd1a82 | ||
|
|
8cc746a864 | ||
|
|
9a870b5102 | ||
|
|
ca17fc8355 | ||
|
|
20561aa919 | ||
|
|
7a16e8aa4e | ||
|
|
b202f842af | ||
|
|
7d5f5252c3 | ||
|
|
2bd4d82b4f | ||
|
|
593be209a4 |
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env bash
|
||||
# Checks pull request commits for AI agent Co-authored-by trailers.
|
||||
# Exits non-zero when any are found and prints fix instructions.
|
||||
set -euo pipefail
|
||||
|
||||
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
|
||||
# Known AI coding-agent trailer patterns (case-insensitive).
|
||||
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
|
||||
AGENT_PATTERNS=(
|
||||
# Anthropic — Claude Code / Amp
|
||||
'noreply@anthropic\.com'
|
||||
# Cursor
|
||||
'cursoragent@cursor\.com'
|
||||
# GitHub Copilot
|
||||
'copilot-swe-agent\[bot\]'
|
||||
'copilot@github\.com'
|
||||
# OpenAI Codex
|
||||
'noreply@openai\.com'
|
||||
'codex@openai\.com'
|
||||
# Aider
|
||||
'aider@aider\.chat'
|
||||
# Google — Gemini / Jules
|
||||
'gemini@google\.com'
|
||||
'jules@google\.com'
|
||||
# Windsurf / Codeium
|
||||
'@codeium\.com'
|
||||
# Devin
|
||||
'devin-ai-integration\[bot\]'
|
||||
'devin@cognition\.ai'
|
||||
'devin@cognition-labs\.com'
|
||||
# Amazon Q Developer
|
||||
'amazon-q-developer'
|
||||
'@amazon\.com.*[Qq].[Dd]eveloper'
|
||||
# Cline
|
||||
'cline-bot'
|
||||
'cline@cline\.ai'
|
||||
# Continue
|
||||
'continue-agent'
|
||||
'continue@continue\.dev'
|
||||
# Sourcegraph
|
||||
'noreply@sourcegraph\.com'
|
||||
# Generic catch-alls for common agent name patterns
|
||||
'Co-authored-by:.*\b[Cc]laude\b'
|
||||
'Co-authored-by:.*\b[Cc]opilot\b'
|
||||
'Co-authored-by:.*\b[Cc]ursor\b'
|
||||
'Co-authored-by:.*\b[Cc]odex\b'
|
||||
'Co-authored-by:.*\b[Gg]emini\b'
|
||||
'Co-authored-by:.*\b[Aa]ider\b'
|
||||
'Co-authored-by:.*\b[Dd]evin\b'
|
||||
'Co-authored-by:.*\b[Ww]indsurf\b'
|
||||
'Co-authored-by:.*\b[Cc]line\b'
|
||||
'Co-authored-by:.*\b[Aa]mazon Q\b'
|
||||
'Co-authored-by:.*\b[Jj]ules\b'
|
||||
'Co-authored-by:.*\bOpenCode\b'
|
||||
)
|
||||
|
||||
# Build a single alternation regex from all patterns.
|
||||
regex=""
|
||||
for pattern in "${AGENT_PATTERNS[@]}"; do
|
||||
if [[ -n "$regex" ]]; then
|
||||
regex="${regex}|${pattern}"
|
||||
else
|
||||
regex="$pattern"
|
||||
fi
|
||||
done
|
||||
|
||||
# Collect Co-authored-by lines from every commit in the PR range.
|
||||
violations=""
|
||||
while IFS= read -r sha; do
|
||||
message="$(git log -1 --format='%B' "$sha")"
|
||||
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
|
||||
if [[ -z "$matched_lines" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
while IFS= read -r line; do
|
||||
if echo "$line" | grep -iqE "$regex"; then
|
||||
short="$(git log -1 --format='%h' "$sha")"
|
||||
violations="${violations} ${short}: ${line}"$'\n'
|
||||
fi
|
||||
done <<< "$matched_lines"
|
||||
done < <(git rev-list "${base_sha}..${head_sha}")
|
||||
|
||||
if [[ -n "$violations" ]]; then
|
||||
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
|
||||
echo ""
|
||||
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
|
||||
echo ""
|
||||
echo "$violations"
|
||||
echo "These trailers should be removed before merging."
|
||||
echo ""
|
||||
echo "To fix, rewrite the commit messages with:"
|
||||
echo " git rebase -i ${base_sha}"
|
||||
echo ""
|
||||
echo "and remove the Co-authored-by lines, then force-push your branch."
|
||||
echo ""
|
||||
echo "If you believe this is a false positive, please open an issue."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No AI agent Co-authored-by trailers found."
|
||||
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Check AI Co-Authors
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*']
|
||||
|
||||
jobs:
|
||||
check-ai-co-authors:
|
||||
name: Check for AI agent co-author trailers
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check commits for AI co-author trailers
|
||||
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
|
||||
@@ -8,7 +8,7 @@ from alembic import context
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
from app.database.models import Base, NAMING_CONVENTION
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
@@ -51,7 +51,10 @@ def run_migrations_online() -> None:
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Add system_metadata and job_id columns to asset_references.
|
||||
Change preview_id FK from assets.id to asset_references.id.
|
||||
|
||||
Revision ID: 0003_add_metadata_job_id
|
||||
Revises: 0002_merge_to_asset_references
|
||||
Create Date: 2026-03-09
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.database.models import NAMING_CONVENTION
|
||||
|
||||
revision = "0003_add_metadata_job_id"
|
||||
down_revision = "0002_merge_to_asset_references"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("system_metadata", sa.JSON(), nullable=True)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("job_id", sa.String(length=36), nullable=True)
|
||||
)
|
||||
|
||||
# Change preview_id FK from assets.id to asset_references.id (self-ref).
|
||||
# Existing values are asset-content IDs that won't match reference IDs,
|
||||
# so null them out first.
|
||||
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_assets", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_asset_references",
|
||||
"asset_references",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
batch_op.create_index(
|
||||
"ix_asset_references_preview_id", ["preview_id"]
|
||||
)
|
||||
|
||||
# Purge any all-null meta rows before adding the constraint
|
||||
op.execute(
|
||||
"DELETE FROM asset_reference_meta"
|
||||
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
|
||||
)
|
||||
with op.batch_alter_table("asset_reference_meta") as batch_op:
|
||||
batch_op.create_check_constraint(
|
||||
"ck_asset_reference_meta_has_value",
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# SQLite doesn't reflect CHECK constraints, so we must declare it
|
||||
# explicitly via table_args for the batch recreate to find it.
|
||||
# Use the fully-rendered constraint name to avoid the naming convention
|
||||
# doubling the prefix.
|
||||
with op.batch_alter_table(
|
||||
"asset_reference_meta",
|
||||
table_args=[
|
||||
sa.CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="ck_asset_reference_meta_has_value",
|
||||
),
|
||||
],
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"ck_asset_reference_meta_has_value", type_="check"
|
||||
)
|
||||
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_index("ix_asset_references_preview_id")
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_assets",
|
||||
"assets",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.drop_column("job_id")
|
||||
batch_op.drop_column("system_metadata")
|
||||
@@ -13,6 +13,7 @@ from pydantic import ValidationError
|
||||
import folder_paths
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
from app.assets.services import schemas
|
||||
from app.assets.api.schemas_in import (
|
||||
AssetValidationError,
|
||||
UploadError,
|
||||
@@ -38,6 +39,7 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
|
||||
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
|
||||
if not user_metadata:
|
||||
return None
|
||||
filename = user_metadata.get("filename")
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
if "input" in tags:
|
||||
view_type = "input"
|
||||
elif "output" in tags:
|
||||
view_type = "output"
|
||||
else:
|
||||
return None
|
||||
|
||||
subfolder = ""
|
||||
if "/" in filename:
|
||||
subfolder, filename = filename.rsplit("/", 1)
|
||||
|
||||
encoded_filename = urllib.parse.quote(filename, safe="")
|
||||
url = f"/api/view?type={view_type}&filename={encoded_filename}"
|
||||
if subfolder:
|
||||
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
|
||||
return url
|
||||
|
||||
|
||||
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
|
||||
"""Build an Asset response from a service result."""
|
||||
if result.ref.preview_id:
|
||||
preview_detail = get_asset_detail(result.ref.preview_id)
|
||||
if preview_detail:
|
||||
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
|
||||
else:
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
preview_url=preview_url,
|
||||
preview_id=result.ref.preview_id,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
@_require_assets_feature_enabled
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
@@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries = [
|
||||
schemas_out.AssetSummary(
|
||||
id=item.ref.id,
|
||||
name=item.ref.name,
|
||||
asset_hash=item.asset.hash if item.asset else None,
|
||||
size=int(item.asset.size_bytes) if item.asset else None,
|
||||
mime_type=item.asset.mime_type if item.asset else None,
|
||||
tags=item.tags,
|
||||
created_at=item.ref.created_at,
|
||||
updated_at=item.ref.updated_at,
|
||||
last_access_time=item.ref.last_access_time,
|
||||
)
|
||||
for item in result.items
|
||||
]
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
@@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
{"id": reference_id},
|
||||
)
|
||||
|
||||
payload = schemas_out.AssetDetail(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
)
|
||||
payload = _build_asset_response(result)
|
||||
except ValueError as e:
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
||||
@@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
@@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
# Derive name from hash if not provided
|
||||
name = body.name
|
||||
if name is None:
|
||||
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
|
||||
|
||||
result = create_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
name=name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
mime_type=body.mime_type,
|
||||
preview_id=body.preview_id,
|
||||
)
|
||||
if result is None:
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
||||
)
|
||||
|
||||
asset = _build_asset_response(result)
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
**asset.model_dump(),
|
||||
created_new=result.created_new,
|
||||
)
|
||||
return web.json_response(payload_out.model_dump(mode="json"), status=201)
|
||||
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
@@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
"name": parsed.provided_name,
|
||||
"user_metadata": parsed.user_metadata_raw,
|
||||
"hash": parsed.provided_hash,
|
||||
"mime_type": parsed.provided_mime_type,
|
||||
"preview_id": parsed.provided_preview_id,
|
||||
}
|
||||
)
|
||||
except ValidationError as ve:
|
||||
@@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
mime_type=spec.mime_type,
|
||||
preview_id=spec.preview_id,
|
||||
)
|
||||
if result is None:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
@@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
client_filename=parsed.file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_hash=spec.hash,
|
||||
mime_type=spec.mime_type,
|
||||
preview_id=spec.preview_id,
|
||||
)
|
||||
except AssetValidationError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
@@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
payload = schemas_out.AssetCreated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
asset = _build_asset_response(result)
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
**asset.model_dump(),
|
||||
created_new=result.created_new,
|
||||
)
|
||||
status = 201 if result.created_new else 200
|
||||
return web.json_response(payload.model_dump(mode="json"), status=status)
|
||||
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
preview_id=body.preview_id,
|
||||
)
|
||||
payload = schemas_out.AssetUpdated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
updated_at=result.ref.updated_at,
|
||||
)
|
||||
payload = _build_asset_response(result)
|
||||
except PermissionError as pe:
|
||||
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
||||
except ValueError as ve:
|
||||
@@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/tags/refine")
|
||||
@_require_assets_feature_enabled
|
||||
async def get_tags_refine(request: web.Request) -> web.Response:
|
||||
"""GET request to get tag histogram for filtered assets."""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _build_validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
tag_counts = list_tag_histogram(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
)
|
||||
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
|
||||
@@ -45,6 +45,8 @@ class ParsedUpload:
|
||||
user_metadata_raw: str | None
|
||||
provided_hash: str | None
|
||||
provided_hash_exists: bool | None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
@@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel):
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_at_least_one_field(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
if all(
|
||||
v is None
|
||||
for v in (self.name, self.user_metadata, self.preview_id)
|
||||
):
|
||||
raise ValueError(
|
||||
"Provide at least one of: name, user_metadata, preview_id."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
name: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
mime_type: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
@@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel):
|
||||
return []
|
||||
|
||||
|
||||
class TagsRefineQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
limit: conint(ge=1, le=1000) = 100
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@@ -186,21 +234,25 @@ class TagsRemove(TagsAdd):
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||
- mime_type: optional MIME type override
|
||||
- preview_id: optional asset_reference ID for preview
|
||||
|
||||
Files are stored using the content hash as filename stem.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
preview_id: str | None = Field(default=None) # references an asset_reference id
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
@@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
raise ValueError("at least one tag is required for uploads")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
|
||||
@@ -4,7 +4,10 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
class Asset(BaseModel):
|
||||
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
|
||||
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
@@ -12,8 +15,14 @@ class AssetSummary(BaseModel):
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
is_immutable: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
prompt_id: str | None = None # deprecated: use job_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -23,50 +32,16 @@ class AssetSummary(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(Asset):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _serialize_updated_at(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _serialize_datetime(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagHistogram(BaseModel):
|
||||
tag_counts: dict[str, int]
|
||||
|
||||
@@ -52,6 +52,8 @@ async def parse_multipart_upload(
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
@@ -128,6 +130,16 @@ async def parse_multipart_upload(
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
elif fname == "id":
|
||||
raise UploadError(
|
||||
400,
|
||||
"UNSUPPORTED_FIELD",
|
||||
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
|
||||
)
|
||||
elif fname == "mime_type":
|
||||
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||
elif fname == "preview_id":
|
||||
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
raise UploadError(
|
||||
@@ -152,6 +164,8 @@ async def parse_multipart_upload(
|
||||
user_metadata_raw=user_metadata_raw,
|
||||
provided_hash=provided_hash,
|
||||
provided_hash_exists=provided_hash_exists,
|
||||
provided_mime_type=provided_mime_type,
|
||||
provided_preview_id=provided_preview_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -45,13 +45,7 @@ class Asset(Base):
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetReference]] = relationship(
|
||||
"AssetReference",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
|
||||
foreign_keys=lambda: [AssetReference.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
# preview_id on AssetReference is a self-referential FK to asset_references.id
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
@@ -91,11 +85,15 @@ class AssetReference(Base):
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("assets.id", ondelete="SET NULL")
|
||||
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
|
||||
)
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True)
|
||||
)
|
||||
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True), nullable=True, default=None
|
||||
)
|
||||
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
@@ -115,10 +113,10 @@ class AssetReference(Base):
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
preview_ref: Mapped[AssetReference | None] = relationship(
|
||||
"AssetReference",
|
||||
foreign_keys=[preview_id],
|
||||
remote_side=lambda: [AssetReference.id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||
@@ -152,6 +150,7 @@ class AssetReference(Base):
|
||||
Index("ix_asset_references_created_at", "created_at"),
|
||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||
Index("ix_asset_references_preview_id", "preview_id"),
|
||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||
CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
@@ -192,6 +191,10 @@ class AssetReferenceMeta(Base):
|
||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||
CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="has_value",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
create_stub_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
@@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
count_active_siblings,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
@@ -31,16 +33,21 @@ from app.assets.database.queries.asset_reference import (
|
||||
get_unenriched_references,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
insert_reference,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
list_references_page,
|
||||
mark_references_missing_outside_prefixes,
|
||||
rebuild_metadata_projection,
|
||||
reference_exists,
|
||||
reference_exists_for_asset_id,
|
||||
restore_references_by_paths,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_system_metadata,
|
||||
soft_delete_reference_by_id,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_is_missing_by_asset_id,
|
||||
update_reference_timestamps,
|
||||
update_reference_updated_at,
|
||||
upsert_reference,
|
||||
@@ -54,6 +61,7 @@ from app.assets.database.queries.tags import (
|
||||
bulk_insert_tags_and_meta,
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
list_tag_counts_for_filtered_assets,
|
||||
list_tags_with_usage,
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
@@ -74,6 +82,8 @@ __all__ = [
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"count_active_siblings",
|
||||
"create_stub_asset",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
@@ -97,20 +107,26 @@ __all__ = [
|
||||
"get_unenriched_references",
|
||||
"get_unreferenced_unhashed_asset_ids",
|
||||
"insert_reference",
|
||||
"list_all_file_paths_by_asset_id",
|
||||
"list_references_by_asset_id",
|
||||
"list_references_page",
|
||||
"list_tag_counts_for_filtered_assets",
|
||||
"list_tags_with_usage",
|
||||
"mark_references_missing_outside_prefixes",
|
||||
"reassign_asset_references",
|
||||
"rebuild_metadata_projection",
|
||||
"reference_exists",
|
||||
"reference_exists_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_tags_from_reference",
|
||||
"restore_references_by_paths",
|
||||
"set_reference_metadata",
|
||||
"set_reference_preview",
|
||||
"set_reference_system_metadata",
|
||||
"soft_delete_reference_by_id",
|
||||
"set_reference_tags",
|
||||
"update_asset_hash_and_mime",
|
||||
"update_is_missing_by_asset_id",
|
||||
"update_reference_access_time",
|
||||
"update_reference_name",
|
||||
"update_reference_timestamps",
|
||||
|
||||
@@ -69,7 +69,7 @@ def upsert_asset(
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
if mime_type and not asset.mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
@@ -78,6 +78,18 @@ def upsert_asset(
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
def create_stub_asset(
|
||||
session: Session,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> Asset:
|
||||
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
@@ -118,7 +130,7 @@ def update_asset_hash_and_mime(
|
||||
return False
|
||||
if asset_hash is not None:
|
||||
asset.hash = asset_hash
|
||||
if mime_type is not None:
|
||||
if mime_type is not None and not asset.mime_type:
|
||||
asset.mime_type = mime_type
|
||||
return True
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from decimal import Decimal
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, exists, select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, noload
|
||||
@@ -24,12 +24,14 @@ from app.assets.database.models import (
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
MAX_BIND_PARAMS,
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_prefix_like_conditions,
|
||||
build_visible_owner_clause,
|
||||
calculate_rows_per_statement,
|
||||
iter_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now
|
||||
|
||||
|
||||
def _check_is_scalar(v):
|
||||
@@ -44,15 +46,6 @@ def _check_is_scalar(v):
|
||||
|
||||
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
"""Convert a scalar value to a typed projection row."""
|
||||
if value is None:
|
||||
return {
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
if isinstance(value, bool):
|
||||
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
@@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
||||
"""Turn a metadata key/value into typed projection rows."""
|
||||
if value is None:
|
||||
return [_scalar_to_row(key, 0, None)]
|
||||
return []
|
||||
|
||||
if _check_is_scalar(value):
|
||||
return [_scalar_to_row(key, 0, value)]
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(_check_is_scalar(x) for x in value):
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
|
||||
|
||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetReferenceMeta.val_json.is_(None),
|
||||
AssetReferenceMeta.val_str.is_(None),
|
||||
AssetReferenceMeta.val_num.is_(None),
|
||||
AssetReferenceMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def get_reference_by_id(
|
||||
@@ -198,6 +114,23 @@ def get_reference_by_file_path(
|
||||
)
|
||||
|
||||
|
||||
def count_active_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
exclude_reference_id: str,
|
||||
) -> int:
|
||||
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||
return (
|
||||
session.query(AssetReference)
|
||||
.filter(
|
||||
AssetReference.asset_id == asset_id,
|
||||
AssetReference.id != exclude_reference_id,
|
||||
AssetReference.deleted_at.is_(None),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
def reference_exists_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -212,6 +145,21 @@ def reference_exists_for_asset_id(
|
||||
return session.execute(q).first() is not None
|
||||
|
||||
|
||||
def reference_exists(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
) -> bool:
|
||||
"""Return True if a reference with the given ID exists (not soft-deleted)."""
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetReference)
|
||||
.where(AssetReference.id == reference_id)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.limit(1)
|
||||
)
|
||||
return session.execute(q).first() is not None
|
||||
|
||||
|
||||
def insert_reference(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -336,8 +284,8 @@ def list_references_page(
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
@@ -366,8 +314,8 @@ def list_references_page(
|
||||
count_stmt = count_stmt.where(
|
||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||
)
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int(session.execute(count_stmt).scalar_one() or 0)
|
||||
refs = session.execute(base).unique().scalars().all()
|
||||
@@ -379,7 +327,7 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
.order_by(AssetReferenceTag.added_at)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@@ -492,6 +440,42 @@ def update_reference_updated_at(
|
||||
)
|
||||
|
||||
|
||||
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
|
||||
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
|
||||
|
||||
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
|
||||
override system keys of the same name.
|
||||
"""
|
||||
session.execute(
|
||||
delete(AssetReferenceMeta).where(
|
||||
AssetReferenceMeta.asset_reference_id == ref.id
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
|
||||
if not merged:
|
||||
return
|
||||
|
||||
rows: list[AssetReferenceMeta] = []
|
||||
for k, v in merged.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetReferenceMeta(
|
||||
asset_reference_id=ref.id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def set_reference_metadata(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
@@ -505,33 +489,24 @@ def set_reference_metadata(
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
|
||||
session.execute(
|
||||
delete(AssetReferenceMeta).where(
|
||||
AssetReferenceMeta.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
rebuild_metadata_projection(session, ref)
|
||||
|
||||
|
||||
def set_reference_system_metadata(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
system_metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""Set system_metadata on a reference and rebuild the merged projection."""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
ref.system_metadata = system_metadata or {}
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetReferenceMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetReferenceMeta(
|
||||
asset_reference_id=reference_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
rebuild_metadata_projection(session, ref)
|
||||
|
||||
|
||||
def delete_reference_by_id(
|
||||
@@ -571,19 +546,19 @@ def soft_delete_reference_by_id(
|
||||
def set_reference_preview(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
preview_reference_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
if preview_reference_id is None:
|
||||
ref.preview_id = None
|
||||
else:
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
ref.preview_id = preview_asset_id
|
||||
if not session.get(AssetReference, preview_reference_id):
|
||||
raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
|
||||
ref.preview_id = preview_reference_id
|
||||
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
@@ -609,6 +584,8 @@ def list_references_by_asset_id(
|
||||
session.execute(
|
||||
select(AssetReference)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.order_by(AssetReference.id.asc())
|
||||
)
|
||||
.scalars()
|
||||
@@ -616,6 +593,25 @@ def list_references_by_asset_id(
|
||||
)
|
||||
|
||||
|
||||
def list_all_file_paths_by_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> list[str]:
|
||||
"""Return every file_path for an asset, including soft-deleted/missing refs.
|
||||
|
||||
Used for orphan cleanup where all on-disk files must be removed.
|
||||
"""
|
||||
return list(
|
||||
session.execute(
|
||||
select(AssetReference.file_path)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.file_path.isnot(None))
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def upsert_reference(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -855,6 +851,22 @@ def bulk_update_is_missing(
|
||||
return total
|
||||
|
||||
|
||||
def update_is_missing_by_asset_id(
|
||||
session: Session, asset_id: str, value: bool
|
||||
) -> int:
|
||||
"""Set is_missing flag for ALL references belonging to an asset.
|
||||
|
||||
Returns: Number of rows updated
|
||||
"""
|
||||
result = session.execute(
|
||||
sa.update(AssetReference)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.values(is_missing=value)
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
||||
"""Delete references by their IDs.
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
import os
|
||||
from typing import Iterable
|
||||
from decimal import Decimal
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from app.assets.database.models import AssetReference
|
||||
from app.assets.helpers import escape_sql_like_string
|
||||
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
||||
from app.assets.helpers import escape_sql_like_string, normalize_tags
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
@@ -52,3 +54,74 @@ def build_prefix_like_conditions(
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||
return conds
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
return sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
@@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
Asset,
|
||||
AssetReference,
|
||||
AssetReferenceMeta,
|
||||
AssetReferenceTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
@@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@@ -117,7 +120,7 @@ def set_reference_tags(
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
||||
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
|
||||
|
||||
|
||||
def add_tags_to_reference(
|
||||
@@ -272,6 +275,12 @@ def list_tags_with_usage(
|
||||
.select_from(AssetReferenceTag)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.subquery()
|
||||
@@ -308,6 +317,12 @@ def list_tags_with_usage(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
)
|
||||
@@ -320,6 +335,53 @@ def list_tags_with_usage(
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def list_tag_counts_for_filtered_assets(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
"""Return tag counts for assets matching the given filters.
|
||||
|
||||
Uses the same filtering logic as list_references_page but returns
|
||||
{tag_name: count} instead of paginated references.
|
||||
"""
|
||||
# Build a subquery of matching reference IDs
|
||||
ref_sq = (
|
||||
select(AssetReference.id)
|
||||
.join(Asset, Asset.id == AssetReference.asset_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
|
||||
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
|
||||
ref_sq = ref_sq.subquery()
|
||||
|
||||
# Count tags across those references
|
||||
q = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name,
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
rows = session.execute(q).all()
|
||||
return {tag_name: int(cnt) for tag_name, cnt in rows}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
|
||||
@@ -13,12 +13,13 @@ from app.assets.database.queries import (
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
reassign_asset_references,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_system_metadata,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
@@ -338,6 +339,7 @@ def build_asset_specs(
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
"job_id": None,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
@@ -426,6 +428,7 @@ def enrich_asset(
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
@@ -489,9 +492,21 @@ def enrich_asset(
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||
# started (e.g. ingest_existing_file updated it), our results are
|
||||
# stale — discard them to avoid overwriting fresh registration data.
|
||||
ref = get_reference_by_id(session, reference_id)
|
||||
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||
session.rollback()
|
||||
logging.info(
|
||||
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||
reference_id,
|
||||
)
|
||||
return ENRICHMENT_STUB
|
||||
|
||||
if extract_metadata and metadata:
|
||||
user_metadata = metadata.to_user_metadata()
|
||||
set_reference_metadata(session, reference_id, user_metadata)
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
if full_hash:
|
||||
existing = get_asset_by_hash(session, full_hash)
|
||||
|
||||
@@ -77,7 +77,9 @@ class _AssetSeeder:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
# RLock is required because _run_scan() drains pending work while
|
||||
# holding _lock and re-enters start() which also acquires _lock.
|
||||
self._lock = threading.RLock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
@@ -92,6 +94,7 @@ class _AssetSeeder:
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
self._pending_enrich: dict | None = None
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
@@ -196,6 +199,42 @@ class _AssetSeeder:
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def enqueue_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||
|
||||
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||
request is stored and will run automatically when the current scan
|
||||
finishes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if started immediately, False if queued for later
|
||||
"""
|
||||
with self._lock:
|
||||
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||
return True
|
||||
if self._pending_enrich is not None:
|
||||
existing_roots = set(self._pending_enrich["roots"])
|
||||
existing_roots.update(roots)
|
||||
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||
self._pending_enrich["compute_hashes"] = (
|
||||
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||
)
|
||||
else:
|
||||
self._pending_enrich = {
|
||||
"roots": roots,
|
||||
"compute_hashes": compute_hashes,
|
||||
}
|
||||
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||
return False
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
@@ -381,9 +420,13 @@ class _AssetSeeder:
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
|
||||
def _reset_to_idle(self) -> None:
|
||||
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
@@ -594,9 +637,18 @@ class _AssetSeeder:
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
self._reset_to_idle()
|
||||
pending = self._pending_enrich
|
||||
if pending is not None:
|
||||
self._pending_enrich = None
|
||||
if not self.start_enrich(
|
||||
roots=pending["roots"],
|
||||
compute_hashes=pending["compute_hashes"],
|
||||
):
|
||||
logging.warning(
|
||||
"Pending enrich scan could not start (roots=%s)",
|
||||
pending["roots"],
|
||||
)
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
@@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
ingest_existing_file,
|
||||
register_output_files,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
@@ -72,6 +74,8 @@ __all__ = [
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"ingest_existing_file",
|
||||
"register_output_files",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
|
||||
@@ -16,10 +16,12 @@ from app.assets.database.queries import (
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_updated_at,
|
||||
@@ -67,6 +69,8 @@ def update_asset_metadata(
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
@@ -103,6 +107,21 @@ def update_asset_metadata(
|
||||
)
|
||||
touched = True
|
||||
|
||||
if mime_type is not None:
|
||||
updated = update_asset_hash_and_mime(
|
||||
session, asset_id=ref.asset_id, mime_type=mime_type
|
||||
)
|
||||
if updated:
|
||||
touched = True
|
||||
|
||||
if preview_id is not None:
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_id,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_reference_updated_at(session, reference_id=reference_id)
|
||||
|
||||
@@ -159,11 +178,9 @@ def delete_asset_reference(
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - delete it and its files
|
||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [
|
||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
||||
]
|
||||
# Orphaned asset - gather ALL file paths (including
|
||||
# soft-deleted / missing refs) so their on-disk files get cleaned up.
|
||||
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
|
||||
# Also include the just-deleted file path
|
||||
if file_path:
|
||||
file_paths.append(file_path)
|
||||
@@ -185,7 +202,7 @@ def delete_asset_reference(
|
||||
|
||||
def set_asset_preview(
|
||||
reference_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
preview_reference_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
@@ -194,7 +211,7 @@ def set_asset_preview(
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
preview_reference_id=preview_reference_id,
|
||||
)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
@@ -263,6 +280,47 @@ def list_assets_page(
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
asset_hash: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult | None:
|
||||
"""Resolve a blake3 hash to an on-disk file path.
|
||||
|
||||
Only references visible to *owner_id* are considered (owner-less
|
||||
references are always visible).
|
||||
|
||||
Returns a DownloadResolutionResult with abs_path, content_type, and
|
||||
download_name, or None if no asset or live path is found.
|
||||
"""
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash)
|
||||
if not asset:
|
||||
return None
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
visible = [
|
||||
r for r in refs
|
||||
if r.owner_id == "" or r.owner_id == owner_id
|
||||
]
|
||||
abs_path = select_best_live_path(visible)
|
||||
if not abs_path:
|
||||
return None
|
||||
display_name = os.path.basename(abs_path)
|
||||
for ref in visible:
|
||||
if ref.file_path == abs_path and ref.name:
|
||||
display_name = ref.name
|
||||
break
|
||||
ctype = (
|
||||
asset.mime_type
|
||||
or mimetypes.guess_type(display_name)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=display_name,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
|
||||
@@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
job_id: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
@@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
job_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
@@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"job_id": spec.get("job_id"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
|
||||
@@ -9,23 +9,29 @@ from sqlalchemy.orm import Session
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
count_active_siblings,
|
||||
create_stub_asset,
|
||||
ensure_tags_exist,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
@@ -65,7 +71,7 @@ def _ingest_file_from_path(
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
@@ -128,6 +134,102 @@ def _ingest_file_from_path(
|
||||
)
|
||||
|
||||
|
||||
def register_output_files(
|
||||
file_paths: Sequence[str],
|
||||
user_metadata: UserMetadata = None,
|
||||
job_id: str | None = None,
|
||||
) -> int:
|
||||
"""Register a batch of output file paths as assets.
|
||||
|
||||
Returns the number of files successfully registered.
|
||||
"""
|
||||
registered = 0
|
||||
for abs_path in file_paths:
|
||||
if not os.path.isfile(abs_path):
|
||||
continue
|
||||
try:
|
||||
if ingest_existing_file(
|
||||
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||
):
|
||||
registered += 1
|
||||
except Exception:
|
||||
logging.exception("Failed to register output: %s", abs_path)
|
||||
return registered
|
||||
|
||||
|
||||
def ingest_existing_file(
|
||||
abs_path: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
extra_tags: Sequence[str] = (),
|
||||
owner_id: str = "",
|
||||
job_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Register an existing on-disk file as an asset stub.
|
||||
|
||||
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||
|
||||
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||
UX visibility.
|
||||
|
||||
Returns True if a row was inserted or updated, False otherwise.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||
|
||||
with create_session() as session:
|
||||
existing_ref = get_reference_by_file_path(session, locator)
|
||||
if existing_ref is not None:
|
||||
now = get_utc_now()
|
||||
existing_ref.mtime_ns = mtime_ns
|
||||
existing_ref.job_id = job_id
|
||||
existing_ref.is_missing = False
|
||||
existing_ref.deleted_at = None
|
||||
existing_ref.updated_at = now
|
||||
existing_ref.enrichment_level = 0
|
||||
|
||||
asset = existing_ref.asset
|
||||
if asset:
|
||||
# If other refs share this asset, detach to a new stub
|
||||
# instead of mutating the shared row.
|
||||
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||
if siblings > 0:
|
||||
new_asset = create_stub_asset(
|
||||
session,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type or asset.mime_type,
|
||||
)
|
||||
existing_ref.asset_id = new_asset.id
|
||||
else:
|
||||
asset.hash = None
|
||||
asset.size_bytes = size_bytes
|
||||
if mime_type:
|
||||
asset.mime_type = mime_type
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
spec = {
|
||||
"abs_path": abs_path,
|
||||
"size_bytes": size_bytes,
|
||||
"mtime_ns": mtime_ns,
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": os.path.basename(abs_path),
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": mime_type,
|
||||
"job_id": job_id,
|
||||
}
|
||||
if tags:
|
||||
ensure_tags_exist(session, tags)
|
||||
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||
session.commit()
|
||||
return result.won_paths > 0
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
@@ -135,6 +237,8 @@ def _register_existing_asset(
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
@@ -143,14 +247,25 @@ def _register_existing_asset(
|
||||
if not asset:
|
||||
raise ValueError(f"No asset with hash {asset_hash}")
|
||||
|
||||
if mime_type and not asset.mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
|
||||
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
@@ -242,6 +357,8 @@ def upload_from_temp_path(
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult:
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||
@@ -270,6 +387,8 @@ def upload_from_temp_path(
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
@@ -291,7 +410,7 @@ def upload_from_temp_path(
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
validate_path_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
content_type = mime_type or (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
@@ -315,7 +434,7 @@ def upload_from_temp_path(
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
preview_id=preview_id,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
@@ -342,30 +461,99 @@ def upload_from_temp_path(
|
||||
)
|
||||
|
||||
|
||||
def register_file_in_place(
|
||||
abs_path: str,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
) -> UploadResult:
|
||||
"""Register an already-saved file in the asset database without moving it.
|
||||
|
||||
Tags are derived from the filesystem path (root category + subfolder names),
|
||||
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||
If the path is not under a known root, only the caller-provided tags are used.
|
||||
"""
|
||||
try:
|
||||
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
except ValueError:
|
||||
path_tags = []
|
||||
merged_tags = normalize_tags([*path_tags, *tags])
|
||||
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
content_type = mime_type or (
|
||||
mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
abs_path=abs_path,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
tags=merged_tags,
|
||||
tag_origin="upload",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def create_from_hash(
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
try:
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
except ValueError:
|
||||
logging.warning("create_from_hash: no asset found for hash %s", canonical)
|
||||
return None
|
||||
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
|
||||
@@ -25,7 +25,9 @@ class ReferenceData:
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime | None
|
||||
system_metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
file_path=ref.file_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
system_metadata=ref.system_metadata,
|
||||
job_id=ref.job_id,
|
||||
created_at=ref.created_at,
|
||||
updated_at=ref.updated_at,
|
||||
last_access_time=ref.last_access_time,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
@@ -6,6 +8,7 @@ from app.assets.database.queries import (
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
|
||||
from app.assets.services.schemas import TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
@@ -73,3 +76,23 @@ def list_tags(
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
with create_session() as session:
|
||||
return list_tag_counts_for_filtered_assets(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
NAMING_CONVENTION = {
|
||||
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
|
||||
@@ -6,6 +6,7 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@@ -377,8 +378,15 @@ class UserManager():
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
dir_name = os.path.dirname(path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(body)
|
||||
os.replace(tmp_path, path)
|
||||
except:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -149,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -262,4 +263,6 @@ else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
if args.enable_dynamic_vram:
|
||||
return True
|
||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||
|
||||
@@ -93,6 +93,50 @@ class IndexListCallbacks:
|
||||
return {}
|
||||
|
||||
|
||||
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
|
||||
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
|
||||
return None
|
||||
cond_tensor = cond_value.cond
|
||||
if temporal_dim >= cond_tensor.ndim:
|
||||
return None
|
||||
|
||||
cond_size = cond_tensor.size(temporal_dim)
|
||||
|
||||
if temporal_scale == 1:
|
||||
expected_size = x_in.size(window.dim) - temporal_offset
|
||||
if cond_size != expected_size:
|
||||
return None
|
||||
|
||||
if temporal_offset == 0 and temporal_scale == 1:
|
||||
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
if temporal_scale > 1:
|
||||
scaled = []
|
||||
for i in indices:
|
||||
for k in range(temporal_scale):
|
||||
si = i * temporal_scale + k
|
||||
if si < cond_size:
|
||||
scaled.append(si)
|
||||
indices = scaled
|
||||
if not indices:
|
||||
return None
|
||||
|
||||
idx = tuple([slice(None)] * temporal_dim + [indices])
|
||||
sliced = cond_tensor[idx].to(device)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if not handled and self._model is not None:
|
||||
result = self._model.resize_cond_for_context_window(
|
||||
cond_key, cond_value, window, x_in, device,
|
||||
retain_index_list=self.cond_retain_index_list)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
@@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
@@ -136,16 +136,7 @@ class ResBlock(nn.Module):
|
||||
ops.Linear(c_hidden, c),
|
||||
)
|
||||
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
||||
|
||||
# Init weights
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||
|
||||
def _norm(self, x, norm):
|
||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
@@ -386,7 +386,7 @@ class Flux(nn.Module):
|
||||
h = max(h, ref.shape[-2] + h_offset)
|
||||
w = max(w, ref.shape[-1] + w_offset)
|
||||
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
|
||||
@@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
@@ -412,6 +413,7 @@ class Attention(nn.Module):
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
@@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
|
||||
# Inject reference audio for ID-LoRA in-context conditioning
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
ref_audio_seq_len = 0
|
||||
if ref_audio is not None:
|
||||
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||
if ref_tokens.shape[0] < ax.shape[0]:
|
||||
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||
ref_audio_seq_len = ref_tokens.shape[1]
|
||||
B = ax.shape[0]
|
||||
|
||||
# Compute negative temporal positions matching ID-LoRA convention:
|
||||
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||
p = self.a_patchifier
|
||||
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||
time_offset = ref_end[-1].item() + tpl
|
||||
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||
|
||||
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
|
||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||
@@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||
target_len = kwargs.get("target_audio_seq_len")
|
||||
if a_timestep.dim() <= 1:
|
||||
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||
if a_timestep is not None:
|
||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||
a_timestep_flat = a_timestep_scaled.flatten()
|
||||
@@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
|
||||
v_embedded_timestep = embedded_timestep[0]
|
||||
a_embedded_timestep = embedded_timestep[1]
|
||||
|
||||
# Trim reference audio tokens before unpatchification
|
||||
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||
if ref_audio_seq_len > 0:
|
||||
ax = ax[:, ref_audio_seq_len:]
|
||||
if a_embedded_timestep.shape[1] > 1:
|
||||
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||
|
||||
# Expand compressed video timestep if needed
|
||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||
v_embedded_timestep = v_embedded_timestep.expand()
|
||||
|
||||
@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
if isinstance(stride, int):
|
||||
self.time_stride = stride
|
||||
else:
|
||||
self.time_stride = stride[0]
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
@@ -58,16 +63,25 @@ class CausalConv3d(nn.Module):
|
||||
pieces = [ cached, x ]
|
||||
if is_end and not causal:
|
||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||
input_length = sum([piece.shape[2] for piece in pieces])
|
||||
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||
|
||||
needs_caching = not is_end
|
||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
||||
if needs_caching and cache_length == 0:
|
||||
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
if needs_caching and x.shape[2] >= cache_length:
|
||||
needs_caching = False
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
|
||||
@@ -233,10 +233,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
checkpoint_fn = (
|
||||
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
sample = checkpoint_fn(down_block)(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return None
|
||||
|
||||
if self.latent_log_var == "uniform":
|
||||
last_channel = sample[:, -1:, ...]
|
||||
@@ -282,9 +283,35 @@ class Encoder(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
|
||||
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||
|
||||
outputs = []
|
||||
samples = [sample[:, :, :1, :, :]]
|
||||
if sample.shape[2] > 1:
|
||||
chunk_t = max(2, max_chunk_size // frame_size)
|
||||
if chunk_t < 4:
|
||||
chunk_t = 2
|
||||
elif chunk_t < 8:
|
||||
chunk_t = 4
|
||||
else:
|
||||
chunk_t = (chunk_t // 8) * 8
|
||||
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||
for chunk_idx, chunk in enumerate(samples):
|
||||
if chunk_idx == len(samples) - 1:
|
||||
mark_conv3d_ended(self)
|
||||
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||
output = self._forward_chunk(chunk)
|
||||
if output is not None:
|
||||
outputs.append(output)
|
||||
|
||||
return torch_cat_if_needed(outputs, dim=2)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
#No encoder support so just flag the end so it doesnt use the cache.
|
||||
mark_conv3d_ended(self)
|
||||
try:
|
||||
return self.forward_orig(*args, **kwargs)
|
||||
finally:
|
||||
@@ -297,7 +324,23 @@ class Encoder(nn.Module):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||
|
||||
def get_max_chunk_size(device: torch.device) -> int:
|
||||
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||
|
||||
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||
return MIN_CHUNK_SIZE
|
||||
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||
return MAX_CHUNK_SIZE
|
||||
|
||||
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||
)
|
||||
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -457,6 +500,17 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||
ts, hs, ws, to = 1, 1, 1, 0
|
||||
for block in self.up_blocks:
|
||||
if isinstance(block, DepthToSpaceUpsample):
|
||||
ts *= block.stride[0]
|
||||
hs *= block.stride[1]
|
||||
ws *= block.stride[2]
|
||||
if block.stride[0] > 1:
|
||||
to = to * block.stride[0] + 1
|
||||
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
@@ -478,11 +532,62 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def decode_output_shape(self, input_shape):
|
||||
c, (ts, hs, ws), to = self._output_scale
|
||||
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||
|
||||
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
t = sample.shape[2]
|
||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||
output_offset[0] += t
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if ended:
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
output_buffer: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
@@ -497,6 +602,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
timestep_shift_scale = None
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
@@ -524,48 +630,18 @@ class Decoder(nn.Module):
|
||||
)
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
output = []
|
||||
if output_buffer is None:
|
||||
output_buffer = torch.empty(
|
||||
self.decode_output_shape(sample.shape),
|
||||
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
output_offset = [0]
|
||||
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
shift, scale = timestep_shift_scale
|
||||
sample = sample * (1 + scale) + shift
|
||||
sample = self.conv_act(sample)
|
||||
if ended:
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
return
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
if (ended):
|
||||
mark_conv3d_ended(up_block)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||
|
||||
if sample is None or sample.shape[2] == 0:
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
return output_buffer
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
try:
|
||||
@@ -689,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.temporal_cache_state = {}
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
if self.stride[0] == 2:
|
||||
tid = threading.get_ident()
|
||||
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||
if cached_input is not None:
|
||||
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||
cached_input = None
|
||||
|
||||
if self.stride[0] == 2 and pad_first:
|
||||
x = torch.cat(
|
||||
[x[:, :, :1, :, :], x], dim=2
|
||||
) # duplicate first frames for padding
|
||||
pad_first = False
|
||||
|
||||
if x.shape[2] < self.stride[0]:
|
||||
cached_input = x
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
return None
|
||||
|
||||
# skip connection
|
||||
x_in = rearrange(
|
||||
@@ -709,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
# conv
|
||||
x = self.conv(x, causal=causal)
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||
if cached_x is not None:
|
||||
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||
cached_x = None
|
||||
else:
|
||||
cached_x = x
|
||||
x = None
|
||||
|
||||
x = x + x_in
|
||||
if x is not None:
|
||||
x = rearrange(
|
||||
x,
|
||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||
p1=self.stride[0],
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
|
||||
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||
|
||||
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||
|
||||
return x
|
||||
|
||||
@@ -1050,6 +1150,8 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
comfy_has_chunked_io = True
|
||||
|
||||
def __init__(self, version=0, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -1192,14 +1294,15 @@ class VideoVAE(nn.Module):
|
||||
}
|
||||
return config
|
||||
|
||||
def encode(self, x):
|
||||
frames_count = x.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
def encode(self, x, device=None):
|
||||
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x):
|
||||
def decode_output_shape(self, input_shape):
|
||||
return self.decoder.decode_output_shape(input_shape)
|
||||
|
||||
def decode(self, x, output_buffer=None):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||
|
||||
@@ -99,7 +99,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@@ -109,22 +109,7 @@ class Resample(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
@@ -145,19 +130,24 @@ class Resample(nn.Module):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
feat_cache[idx] = x
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
cache_x = x[:, :, -1:, :, :]
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
deferred_x = feat_cache[idx + 1]
|
||||
if deferred_x is not None:
|
||||
x = torch.cat([deferred_x, x], 2)
|
||||
feat_cache[idx + 1] = None
|
||||
|
||||
if x.shape[2] == 1 and not final:
|
||||
feat_cache[idx + 1] = x
|
||||
x = None
|
||||
|
||||
feat_idx[0] += 2
|
||||
return x
|
||||
|
||||
|
||||
@@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
@@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if x is None:
|
||||
return None
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -389,18 +360,48 @@ class Decoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
|
||||
x = x_ref[0]
|
||||
x_ref[0] = None
|
||||
if layer_idx >= len(self.upsamples):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
return
|
||||
|
||||
layer = self.upsamples[layer_idx]
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
|
||||
for frame_idx in range(0, x.shape[2], 2):
|
||||
self.run_up(
|
||||
layer_idx + 1,
|
||||
[x[:, :, frame_idx:frame_idx + 2, :, :]],
|
||||
feat_cache,
|
||||
feat_idx.copy(),
|
||||
out_chunks,
|
||||
)
|
||||
del x
|
||||
return
|
||||
|
||||
next_x_ref = [x]
|
||||
del x
|
||||
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -409,42 +410,21 @@ class Decoder3d(nn.Module):
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
out_chunks = []
|
||||
|
||||
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||
return out_chunks
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
def count_cache_layers(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -482,11 +462,12 @@ class WanVAE(nn.Module):
|
||||
conv_idx = [0]
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
t = 1 + ((t - 1) // 4) * 4
|
||||
iter_ = 1 + (t - 1) // 2
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
feat_map = [None] * count_cache_layers(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
@@ -496,20 +477,23 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_idx=conv_idx,
|
||||
final=(i == (iter_ - 1)))
|
||||
if out_ is None:
|
||||
continue
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
# z: [b,c,t,h,w]
|
||||
iter_ = z.shape[2]
|
||||
iter_ = 1 + z.shape[2] // 2
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
feat_map = [None] * count_cache_layers(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -520,8 +504,8 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
return out
|
||||
out += out_
|
||||
return torch.cat(out, 2)
|
||||
|
||||
@@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
|
||||
@@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@@ -285,6 +286,12 @@ class BaseModel(torch.nn.Module):
|
||||
return data
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||
Return a sliced cond object, or None to fall through to default handling.
|
||||
Use comfy.context_windows.slice_cond() for common cases."""
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@@ -930,9 +937,10 @@ class LongCatImage(Flux):
|
||||
transformer_options = transformer_options.copy()
|
||||
rope_opts = transformer_options.get("rope_options", {})
|
||||
rope_opts = dict(rope_opts)
|
||||
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
|
||||
rope_opts.setdefault("shift_t", 1.0)
|
||||
rope_opts.setdefault("shift_y", 512.0)
|
||||
rope_opts.setdefault("shift_x", 512.0)
|
||||
rope_opts.setdefault("shift_y", pe_len)
|
||||
rope_opts.setdefault("shift_x", pe_len)
|
||||
transformer_options["rope_options"] = rope_opts
|
||||
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
@@ -1053,6 +1061,10 @@ class LTXAV(BaseModel):
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
if ref_audio is not None:
|
||||
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
@@ -1375,6 +1387,11 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "vace_context":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
@@ -1427,6 +1444,11 @@ class WAN21_HuMo(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_Animate(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||
@@ -1444,6 +1466,13 @@ class WAN22_Animate(WAN21):
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "face_pixel_values":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
|
||||
if cond_key == "pose_latents":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
@@ -1480,6 +1509,11 @@ class WAN22_S2V(WAN21):
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "audio_embed":
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@@ -55,6 +55,7 @@ total_vram = 0
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
training_fp8_bwd = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
@@ -541,6 +542,7 @@ class LoadedModel:
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
self._patcher_finalizer.atexit = False
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
@@ -587,6 +589,7 @@ class LoadedModel:
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
self.model_finalizer.atexit = False
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
|
||||
134
comfy/ops.py
134
comfy/ops.py
@@ -776,6 +776,104 @@ from .quant_ops import (
|
||||
)
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||
|
||||
When training_fp8_bwd is enabled:
|
||||
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||
- Cached input is FP8 (half the memory of bf16)
|
||||
|
||||
When training_fp8_bwd is disabled:
|
||||
- Forward: quantize input per layout, use quantized matmul
|
||||
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input for forward (same layout as weight)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
q_input = inp
|
||||
|
||||
w = weight.detach() if weight.requires_grad else weight
|
||||
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Unflatten output to match original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
# Save for backward
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||
|
||||
if ctx.fp8_bwd:
|
||||
# Cache FP8 quantized input — half the memory of bf16
|
||||
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||
ctx.q_input = q_input # already FP8, reuse
|
||||
else:
|
||||
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||
ctx.save_for_backward(weight)
|
||||
else:
|
||||
ctx.q_input = None
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Value casting — only difference between fp8 and non-fp8 paths
|
||||
if ctx.fp8_bwd:
|
||||
weight, = ctx.saved_tensors
|
||||
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||
weight_mm = weight
|
||||
elif isinstance(weight, QuantizedTensor):
|
||||
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
else:
|
||||
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||
input_mm = ctx.q_input
|
||||
else:
|
||||
input_float, weight = ctx.saved_tensors
|
||||
# Standard tensors → torch.mm does regular matmul
|
||||
grad_mm = grad_2d
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_mm = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_mm = weight.to(compute_dtype)
|
||||
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||
|
||||
# Computation — same for both paths, dispatch handles the rest
|
||||
grad_input = torch.mm(grad_mm, weight_mm)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
@@ -970,10 +1068,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
_use_quantized = (
|
||||
getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True
|
||||
)
|
||||
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
|
||||
output = QuantLinearFunc.apply(
|
||||
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||
)
|
||||
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@@ -1021,7 +1146,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
p = fn(param)
|
||||
if p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
|
||||
@@ -8,12 +8,12 @@ import comfy.nested_tensor
|
||||
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||
|
||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||
noises = []
|
||||
for i in range(unique_inds[-1]+1):
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return samples
|
||||
|
||||
@@ -985,8 +985,8 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
noise = noise.to(device)
|
||||
latent_image = latent_image.to(device)
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
@@ -1028,6 +1028,7 @@ class CFGGuider:
|
||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||
else:
|
||||
denoise_mask = denoise_masks[0]
|
||||
denoise_mask = denoise_mask.float()
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
|
||||
34
comfy/sd.py
34
comfy/sd.py
@@ -455,7 +455,7 @@ class VAE:
|
||||
self.output_channels = 3
|
||||
self.pad_channel_value = None
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
@@ -951,12 +951,23 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
@@ -967,6 +978,7 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
@@ -1027,8 +1039,13 @@ class VAE:
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
@@ -1043,6 +1060,7 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
|
||||
@@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
||||
out, pooled = o[:2]
|
||||
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
||||
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||
else:
|
||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
||||
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||
|
||||
if len(o) > 2:
|
||||
extra = {}
|
||||
for k in o[2]:
|
||||
v = o[2][k]
|
||||
if k == "attention_mask":
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
||||
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||
extra[k] = v
|
||||
|
||||
r = r + (extra,)
|
||||
|
||||
@@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
if attention_mask is not None:
|
||||
# Assign compact sequential positions to attended tokens only,
|
||||
# skipping over padding so post-padding tokens aren't inflated.
|
||||
after_mask = attention_mask[0, end:]
|
||||
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
|
||||
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
|
||||
else:
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
|
||||
@@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
|
||||
return [output]
|
||||
|
||||
|
||||
IMAGE_PAD_TOKEN_ID = 151655
|
||||
|
||||
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
@@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
name="qwen25_7b",
|
||||
tokenizer=LongCatImageBaseTokenizer,
|
||||
)
|
||||
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith("<|im_start|>"):
|
||||
skip_template = True
|
||||
@@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
|
||||
)
|
||||
else:
|
||||
has_images = images is not None and len(images) > 0
|
||||
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
|
||||
|
||||
prefix_ids = base_tok.tokenizer(
|
||||
self.longcat_template_prefix, add_special_tokens=False
|
||||
template_prefix, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
suffix_ids = base_tok.tokenizer(
|
||||
self.longcat_template_suffix, add_special_tokens=False
|
||||
self.SUFFIX, add_special_tokens=False
|
||||
)["input_ids"]
|
||||
|
||||
prompt_tokens = base_tok.tokenize_with_weights(
|
||||
@@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
suffix_pairs = [(t, 1.0) for t in suffix_ids]
|
||||
|
||||
combined = prefix_pairs + prompt_pairs + suffix_pairs
|
||||
|
||||
if has_images:
|
||||
embed_count = 0
|
||||
for i in range(len(combined)):
|
||||
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
|
||||
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
|
||||
embed_count += 1
|
||||
|
||||
tokens = {"qwen25_7b": [combined]}
|
||||
|
||||
return tokens
|
||||
|
||||
@@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
|
||||
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
|
||||
|
||||
hidden_states = self.merger(hidden_states)
|
||||
# Potentially important for spatially precise edits. This is present in the HF implementation.
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
hidden_states = hidden_states[reverse_indices, :]
|
||||
return hidden_states
|
||||
|
||||
@@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
# Clone inference tensors (created under torch.inference_mode) since
|
||||
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||
value = value.clone()
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
@@ -1131,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
out = output[b:b+1].zero_()
|
||||
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
|
||||
@@ -1147,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
@@ -1170,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
output[b:b+1] = out/out_div
|
||||
out.div_(out_div)
|
||||
return output
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
|
||||
@@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
CurvePoint,
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -13,4 +17,8 @@ __all__ = [
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@@ -7,4 +8,8 @@ __all__ = [
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"CurvePoint",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
@staticmethod
|
||||
def from_raw(data) -> CurveInput:
|
||||
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||
|
||||
Accepts:
|
||||
- A ``CurveInput`` instance (returned as-is).
|
||||
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||
"""
|
||||
if isinstance(data, CurveInput):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
raw_points = data["points"]
|
||||
interpolation = data.get("interpolation", "monotone_cubic")
|
||||
else:
|
||||
raw_points = data
|
||||
interpolation = "monotone_cubic"
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
if interpolation == "linear":
|
||||
return LinearCurve(points)
|
||||
if interpolation != "monotone_cubic":
|
||||
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||
return MonotoneCubicCurve(points)
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
Type = list[CurvePoint]
|
||||
from comfy_api.input import CurvePoint
|
||||
if TYPE_CHECKING:
|
||||
Type = CurveInput_
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
@@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.default is not None:
|
||||
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||
return d
|
||||
|
||||
|
||||
@comfytype(io_type="HISTOGRAM")
|
||||
class Histogram(ComfyTypeIO):
|
||||
"""A histogram represented as a list of bin counts."""
|
||||
Type = list[int]
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
@@ -2240,5 +2253,6 @@ __all__ = [
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"Curve",
|
||||
"Histogram",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
thought: bool | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
|
||||
@@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
image: InputUrlObject | None = Field(...)
|
||||
image: InputUrlObject | None = Field(None)
|
||||
reference_images: list[InputUrlObject] | None = Field(None)
|
||||
duration: int = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class VideoExtensionRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
video: InputUrlObject = Field(...)
|
||||
duration: int = Field(default=6)
|
||||
model: str | None = Field(default=None)
|
||||
|
||||
|
||||
class VideoEditRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
|
||||
43
comfy_api_nodes/apis/quiver.py
Normal file
43
comfy_api_nodes/apis/quiver.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class QuiverImageObject(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class QuiverTextToSVGRequest(BaseModel):
|
||||
model: str = Field(default="arrow-preview")
|
||||
prompt: str = Field(...)
|
||||
instructions: str | None = Field(default=None)
|
||||
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
|
||||
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
class QuiverImageToSVGRequest(BaseModel):
|
||||
model: str = Field(default="arrow-preview")
|
||||
image: QuiverImageObject = Field(...)
|
||||
auto_crop: bool | None = Field(default=None)
|
||||
target_size: int | None = Field(default=None, ge=128, le=4096)
|
||||
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
class QuiverSVGResponseItem(BaseModel):
|
||||
svg: str = Field(...)
|
||||
mime_type: str | None = Field(default="image/svg+xml")
|
||||
|
||||
|
||||
class QuiverSVGUsage(BaseModel):
|
||||
total_tokens: int | None = Field(default=None)
|
||||
input_tokens: int | None = Field(default=None)
|
||||
output_tokens: int | None = Field(default=None)
|
||||
|
||||
|
||||
class QuiverSVGResponse(BaseModel):
|
||||
id: str | None = Field(default=None)
|
||||
created: int | None = Field(default=None)
|
||||
data: list[QuiverSVGResponseItem] = Field(...)
|
||||
usage: QuiverSVGUsage | None = Field(default=None)
|
||||
@@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
|
||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
@@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03}""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
]
|
||||
return await process_video_task(
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@@ -952,6 +957,12 @@ async def process_video_task(
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
if payload.model in DEPRECATED_MODELS:
|
||||
logger.warning(
|
||||
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||
payload.model,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
|
||||
@@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||
$m := widgets.model;
|
||||
$r := widgets.resolution;
|
||||
$isFlash := $contains($m, "nano banana 2");
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
||||
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
@@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if (part.thought is True) != thought:
|
||||
continue
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
@@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
|
||||
@@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
||||
ImageGenerationResponse,
|
||||
InputUrlObject,
|
||||
VideoEditRequest,
|
||||
VideoExtensionRequest,
|
||||
VideoGenerationRequest,
|
||||
VideoGenerationResponse,
|
||||
VideoStatusResponse,
|
||||
@@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
@@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_grok_video_price(response) -> float | None:
|
||||
price = _extract_grok_price(response)
|
||||
if price is not None:
|
||||
return price * 1.43
|
||||
return None
|
||||
|
||||
|
||||
class GrokImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "grok-imagine-video-beta":
|
||||
model = "grok-imagine-video"
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of the desired video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="reference_",
|
||||
min=1,
|
||||
max=7,
|
||||
),
|
||||
tooltip="Up to 7 reference images to guide the video generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=6,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="The duration of the output video in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model.duration", "model.resolution"],
|
||||
input_groups=["model.reference_images"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$refs := inputGroups["model.reference_images"];
|
||||
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
ref_image_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
list(model["reference_images"].values()),
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading base images",
|
||||
max_images=7,
|
||||
)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model["model"],
|
||||
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||
prompt=prompt,
|
||||
resolution=model["resolution"],
|
||||
duration=model["duration"],
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
seed=seed,
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokVideoExtendNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text description of what should happen next in the video.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"grok-imagine-video",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=2,
|
||||
max=10,
|
||||
step=1,
|
||||
tooltip="Length of the extension in seconds.",
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for video extension.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||
expr="""
|
||||
(
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||
video_size = get_fs_object_size(video.get_stream_source())
|
||||
if video_size > 50 * 1024 * 1024:
|
||||
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||
data=VideoExtensionRequest(
|
||||
prompt=prompt,
|
||||
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||
duration=model["duration"],
|
||||
),
|
||||
response_model=VideoGenerationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_video_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
|
||||
class GrokExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
||||
GrokImageNode,
|
||||
GrokImageEditNode,
|
||||
GrokVideoNode,
|
||||
GrokVideoReferenceNode,
|
||||
GrokVideoEditNode,
|
||||
GrokVideoExtendNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
291
comfy_api_nodes/nodes_quiver.py
Normal file
291
comfy_api_nodes/nodes_quiver.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.quiver import (
|
||||
QuiverImageObject,
|
||||
QuiverImageToSVGRequest,
|
||||
QuiverSVGResponse,
|
||||
QuiverTextToSVGRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_extras.nodes_images import SVG
|
||||
|
||||
|
||||
class QuiverTextToSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="QuiverTextToSVGNode",
|
||||
display_name="Quiver Text to SVG",
|
||||
category="api node/image/Quiver",
|
||||
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired SVG output.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"instructions",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Additional style or formatting guidance.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="ref_",
|
||||
min=0,
|
||||
max=4,
|
||||
),
|
||||
tooltip="Up to 4 reference images to guide the generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"arrow-preview",
|
||||
[
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Randomness control. Higher values increase randomness.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=1.0,
|
||||
min=0.05,
|
||||
max=1.0,
|
||||
step=0.05,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Nucleus sampling parameter.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"presence_penalty",
|
||||
default=0.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Token presence penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for SVG generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.429}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
instructions: str = None,
|
||||
reference_images: IO.Autogrow.Type = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
|
||||
references = None
|
||||
if reference_images:
|
||||
references = []
|
||||
for key in reference_images:
|
||||
url = await upload_image_to_comfyapi(cls, reference_images[key])
|
||||
references.append(QuiverImageObject(url=url))
|
||||
if len(references) > 4:
|
||||
raise ValueError("Maximum 4 reference images are allowed.")
|
||||
|
||||
instructions_val = instructions.strip() if instructions else None
|
||||
if instructions_val == "":
|
||||
instructions_val = None
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
|
||||
response_model=QuiverSVGResponse,
|
||||
data=QuiverTextToSVGRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
instructions=instructions_val,
|
||||
references=references,
|
||||
temperature=model.get("temperature"),
|
||||
top_p=model.get("top_p"),
|
||||
presence_penalty=model.get("presence_penalty"),
|
||||
),
|
||||
)
|
||||
|
||||
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class QuiverImageToSVGNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="QuiverImageToSVGNode",
|
||||
display_name="Quiver Image to SVG",
|
||||
category="api node/image/Quiver",
|
||||
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Input image to vectorize.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_crop",
|
||||
default=False,
|
||||
tooltip="Automatically crop to the dominant subject.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"arrow-preview",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"target_size",
|
||||
default=1024,
|
||||
min=128,
|
||||
max=4096,
|
||||
tooltip="Square resize target in pixels.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Randomness control. Higher values increase randomness.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=1.0,
|
||||
min=0.05,
|
||||
max=1.0,
|
||||
step=0.05,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Nucleus sampling parameter.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"presence_penalty",
|
||||
default=0.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Token presence penalty.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for SVG vectorization.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.429}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image,
|
||||
auto_crop: bool,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_url = await upload_image_to_comfyapi(cls, image)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
|
||||
response_model=QuiverSVGResponse,
|
||||
data=QuiverImageToSVGRequest(
|
||||
model=model["model"],
|
||||
image=QuiverImageObject(url=image_url),
|
||||
auto_crop=auto_crop if auto_crop else None,
|
||||
target_size=model.get("target_size"),
|
||||
temperature=model.get("temperature"),
|
||||
top_p=model.get("top_p"),
|
||||
presence_penalty=model.get("presence_penalty"),
|
||||
),
|
||||
)
|
||||
|
||||
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class QuiverExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
QuiverTextToSVGNode,
|
||||
QuiverImageToSVGNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> QuiverExtension:
|
||||
return QuiverExtension()
|
||||
@@ -3,6 +3,7 @@ from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import torch
|
||||
|
||||
|
||||
class Canny(io.ComfyNode):
|
||||
@@ -29,8 +30,8 @@ class Canny(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
return io.NodeOutput(img_out)
|
||||
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
|
||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.input import CurveInput
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CurveEditor(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CurveEditor",
|
||||
display_name="Curve Editor",
|
||||
category="utils",
|
||||
inputs=[
|
||||
io.Curve.Input("curve"),
|
||||
io.Histogram.Input("histogram", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Curve.Output("curve"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||
result = CurveInput.from_raw(curve)
|
||||
|
||||
ui = {}
|
||||
if histogram is not None:
|
||||
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||
|
||||
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||
|
||||
|
||||
class CurveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [CurveEditor]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return CurveExtension()
|
||||
@@ -3,6 +3,7 @@ import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import math
|
||||
import numpy as np
|
||||
@@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
return io.NodeOutput(video_latent, audio_latent)
|
||||
|
||||
|
||||
class LTXVReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="LTXVReferenceAudio",
|
||||
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||
category="conditioning/audio",
|
||||
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
|
||||
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
|
||||
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
|
||||
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
|
||||
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||
# Encode reference audio to latents and patchify
|
||||
audio_latents = audio_vae.encode(reference_audio)
|
||||
b, c, t, f = audio_latents.shape
|
||||
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||
ref_audio = {"tokens": ref_tokens}
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
|
||||
|
||||
# Patch model with identity guidance
|
||||
m = model.clone()
|
||||
scale = identity_guidance_scale
|
||||
model_sampling = m.get_model_object("model_sampling")
|
||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def post_cfg_function(args):
|
||||
if scale == 0:
|
||||
return args["denoised"]
|
||||
|
||||
sigma = args["sigma"]
|
||||
sigma_ = sigma[0].item()
|
||||
if sigma_ > sigma_start or sigma_ < sigma_end:
|
||||
return args["denoised"]
|
||||
|
||||
cond_pred = args["cond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
model_options = args["model_options"].copy()
|
||||
x = args["input"]
|
||||
|
||||
# Strip ref_audio from conditioning for the no-reference pass
|
||||
noref_cond = []
|
||||
for entry in cond:
|
||||
new_entry = entry.copy()
|
||||
mc = new_entry.get("model_conds", {}).copy()
|
||||
mc.pop("ref_audio", None)
|
||||
new_entry["model_conds"] = mc
|
||||
noref_cond.append(new_entry)
|
||||
|
||||
(pred_noref,) = comfy.samplers.calc_cond_batch(
|
||||
args["model"], [noref_cond], x, sigma, model_options
|
||||
)
|
||||
|
||||
return cfg_result + (cond_pred - pred_noref) * scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
return io.NodeOutput(m, positive, negative)
|
||||
|
||||
|
||||
class LtxvExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
||||
LTXVCropGuides,
|
||||
LTXVConcatAVLatent,
|
||||
LTXVSeparateAVLatent,
|
||||
LTXVReferenceAudio,
|
||||
]
|
||||
|
||||
|
||||
|
||||
79
comfy_extras/nodes_number_convert.py
Normal file
79
comfy_extras/nodes_number_convert.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Number Convert node for unified numeric type conversion.
|
||||
|
||||
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||
inputs into FLOAT and INT outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class NumberConvertNode(io.ComfyNode):
|
||||
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ComfyNumberConvert",
|
||||
display_name="Number Convert",
|
||||
category="math",
|
||||
search_aliases=[
|
||||
"int to float", "float to int", "number convert",
|
||||
"int2float", "float2int", "cast", "parse number",
|
||||
"string to number", "bool to int",
|
||||
],
|
||||
inputs=[
|
||||
io.MultiType.Input(
|
||||
"value",
|
||||
[io.Int, io.Float, io.String, io.Boolean],
|
||||
display_name="value",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Float.Output(display_name="FLOAT"),
|
||||
io.Int.Output(display_name="INT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, value) -> io.NodeOutput:
|
||||
if isinstance(value, bool):
|
||||
float_val = 1.0 if value else 0.0
|
||||
elif isinstance(value, (int, float)):
|
||||
float_val = float(value)
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise ValueError("Cannot convert empty string to number.")
|
||||
try:
|
||||
float_val = float(text)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Cannot convert string to number: {value!r}"
|
||||
) from None
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported input type: {type(value).__name__}"
|
||||
)
|
||||
|
||||
if not math.isfinite(float_val):
|
||||
raise ValueError(
|
||||
f"Cannot convert non-finite value to number: {float_val}"
|
||||
)
|
||||
|
||||
return io.NodeOutput(float_val, int(float_val))
|
||||
|
||||
|
||||
class NumberConvertExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [NumberConvertNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||
return NumberConvertExtension()
|
||||
@@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
@@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
training_dtype=torch.bfloat16,
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
use_grad_scaler=False,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# GradScaler for fp16 training
|
||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
if bucket_latents is not None:
|
||||
self._init_bucket_data(bucket_latents)
|
||||
@@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
batch_sigmas.requires_grad_(True),
|
||||
**batch_extra_args,
|
||||
)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||
if bwd:
|
||||
bwd_loss = loss / self.grad_acc
|
||||
bwd_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(bwd_loss).backward()
|
||||
else:
|
||||
bwd_loss.backward()
|
||||
return loss
|
||||
|
||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||
@@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(total_loss).backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
@@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.unscale_(self.optimizer)
|
||||
for param_groups in self.optimizer.param_groups:
|
||||
for param in param_groups["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||
self.optimizer.step()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.step(self.optimizer)
|
||||
self.grad_scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
ui_pbar.update(1)
|
||||
torch.cuda.empty_cache()
|
||||
@@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
|
||||
),
|
||||
io.Combo.Input(
|
||||
"training_dtype",
|
||||
options=["bf16", "fp32"],
|
||||
options=["bf16", "fp32", "none"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for training.",
|
||||
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"lora_dtype",
|
||||
@@ -1014,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for lora.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"quantized_backward",
|
||||
default=False,
|
||||
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"algorithm",
|
||||
options=list(adapter_maps.keys()),
|
||||
@@ -1035,7 +1056,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
||||
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
@@ -1081,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
quantized_backward,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
@@ -1101,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed = seed[0]
|
||||
training_dtype = training_dtype[0]
|
||||
lora_dtype = lora_dtype[0]
|
||||
quantized_backward = quantized_backward[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
@@ -1109,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
|
||||
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||
|
||||
# Process latents based on mode
|
||||
if bucket_mode:
|
||||
latents = _process_latents_bucket_mode(latents)
|
||||
@@ -1120,22 +1145,35 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
if mp.is_dynamic():
|
||||
if not bypass_mode:
|
||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
||||
bypass_mode = True
|
||||
offloading = True
|
||||
elif offloading:
|
||||
if not bypass_mode:
|
||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
@@ -1201,6 +1239,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
@@ -1213,6 +1252,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
@@ -1337,7 +1377,7 @@ class SaveLoRA(io.ComfyNode):
|
||||
io.Int.Input(
|
||||
"steps",
|
||||
optional=True,
|
||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.17.0"
|
||||
__version__ = "0.18.1"
|
||||
|
||||
59
main.py
59
main.py
@@ -9,6 +9,8 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.services import register_output_files
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
@@ -137,7 +139,16 @@ def execute_prestartup_script():
|
||||
spec.loader.exec_module(module)
|
||||
return True
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
from nodes import NODE_STARTUP_ERRORS, get_module_name
|
||||
node_module_name = get_module_name(os.path.dirname(script_path))
|
||||
NODE_STARTUP_ERRORS[node_module_name] = {
|
||||
"module_path": os.path.dirname(script_path),
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"phase": "prestartup",
|
||||
}
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
@@ -192,7 +203,6 @@ if 'torch' in sys.modules:
|
||||
|
||||
|
||||
import comfy.utils
|
||||
from app.assets.seeder import asset_seeder
|
||||
|
||||
import execution
|
||||
import server
|
||||
@@ -206,8 +216,8 @@ import hook_breaker_ac10a0
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
|
||||
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
@@ -240,6 +250,38 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||
"""Extract absolute file paths for output items from a history result."""
|
||||
paths: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for node_output in history_result.get("outputs", {}).values():
|
||||
for items in node_output.values():
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type not in ("output", "temp"):
|
||||
continue
|
||||
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||
if base_dir is None:
|
||||
continue
|
||||
base_dir = os.path.abspath(base_dir)
|
||||
filename = item.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
abs_path = os.path.abspath(
|
||||
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||
)
|
||||
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||
continue
|
||||
if abs_path not in seen:
|
||||
seen.add(abs_path)
|
||||
paths.append(abs_path)
|
||||
return paths
|
||||
|
||||
|
||||
def prompt_worker(q, server_instance):
|
||||
current_time: float = 0.0
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
@@ -274,6 +316,7 @@ def prompt_worker(q, server_instance):
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
@@ -296,6 +339,10 @@ def prompt_worker(q, server_instance):
|
||||
else:
|
||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
paths = _collect_output_absolute_paths(e.history_result)
|
||||
register_output_files(paths, job_id=prompt_id)
|
||||
|
||||
flags = q.get_flags()
|
||||
free_memory = flags.get("free_memory", False)
|
||||
|
||||
@@ -317,6 +364,9 @@ def prompt_worker(q, server_instance):
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
if not asset_seeder.is_disabled():
|
||||
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
@@ -471,6 +521,9 @@ if __name__ == "__main__":
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
|
||||
if args.disable_dynamic_vram:
|
||||
logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
|
||||
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
try:
|
||||
x = start_all_func()
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b5
|
||||
comfyui_manager==4.1b8
|
||||
|
||||
22
nodes.py
22
nodes.py
@@ -952,7 +952,7 @@ class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
@@ -1966,9 +1966,11 @@ class EmptyImage:
|
||||
CATEGORY = "image"
|
||||
|
||||
def generate(self, width, height, batch_size=1, color=0):
|
||||
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
|
||||
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
|
||||
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
device = comfy.model_management.intermediate_device()
|
||||
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype)
|
||||
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype)
|
||||
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype)
|
||||
return (torch.cat((r, g, b), dim=-1), )
|
||||
|
||||
class ImagePadForOutpaint:
|
||||
@@ -2179,6 +2181,9 @@ EXTENSION_WEB_DIRS = {}
|
||||
# Dictionary of successfully loaded module names and associated directories.
|
||||
LOADED_MODULE_DIRS = {}
|
||||
|
||||
# Dictionary of custom node startup errors, keyed by module name.
|
||||
NODE_STARTUP_ERRORS: dict[str, dict] = {}
|
||||
|
||||
|
||||
def get_module_name(module_path: str) -> str:
|
||||
"""
|
||||
@@ -2296,6 +2301,13 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
logging.warning(f"Cannot import {module_path} module for custom nodes: {e}")
|
||||
module_name = get_module_name(module_path)
|
||||
NODE_STARTUP_ERRORS[module_name] = {
|
||||
"module_path": module_path,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"phase": "import",
|
||||
}
|
||||
return False
|
||||
|
||||
async def init_external_custom_nodes():
|
||||
@@ -2452,7 +2464,9 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.17.0"
|
||||
version = "0.18.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-frontend-package==1.42.8
|
||||
comfyui-workflow-templates==0.9.36
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
83
server.py
83
server.py
@@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -419,7 +421,24 @@ class PromptServer():
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
|
||||
|
||||
if args.enable_assets:
|
||||
try:
|
||||
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||
resp["asset"] = {
|
||||
"id": result.ref.id,
|
||||
"name": result.ref.name,
|
||||
"asset_hash": result.asset.hash,
|
||||
"size": result.asset.size_bytes,
|
||||
"mime_type": result.asset.mime_type,
|
||||
"tags": result.tags,
|
||||
}
|
||||
except Exception:
|
||||
logging.warning("Failed to register uploaded image as asset", exc_info=True)
|
||||
|
||||
return web.json_response(resp)
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@@ -479,30 +498,43 @@ class PromptServer():
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
# The frontend's LoadImage combo widget uses asset_hash values
|
||||
# (e.g. "blake3:...") as widget values. When litegraph renders the
|
||||
# node preview, it constructs /view?filename=<asset_hash>, so this
|
||||
# endpoint must resolve blake3 hashes to their on-disk file paths.
|
||||
if filename.startswith("blake3:"):
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
result = resolve_hash_to_path(filename, owner_id=owner_id)
|
||||
if result is None:
|
||||
return web.Response(status=404)
|
||||
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
|
||||
else:
|
||||
resolved_content_type = None
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
|
||||
if "subfolder" in request.rel_url.query:
|
||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
if "subfolder" in request.rel_url.query:
|
||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
if 'preview' in request.rel_url.query:
|
||||
@@ -562,8 +594,13 @@ class PromptServer():
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
content_type = (
|
||||
resolved_content_type
|
||||
or mimetypes.guess_type(filename)[0]
|
||||
or 'application/octet-stream'
|
||||
)
|
||||
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
@@ -716,6 +753,10 @@ class PromptServer():
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/custom_node_startup_errors")
|
||||
async def get_custom_node_startup_errors(request):
|
||||
return web.json_response(nodes.NODE_STARTUP_ERRORS)
|
||||
|
||||
@routes.get("/api/jobs")
|
||||
async def get_jobs(request):
|
||||
"""List all jobs with filtering, sorting, and pagination.
|
||||
|
||||
57
tests-unit/app_test/test_migrations.py
Normal file
57
tests-unit/app_test/test_migrations.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
|
||||
|
||||
This catches problems like unnamed FK constraints that prevent batch-mode
|
||||
drop_constraint from working on real SQLite files (see MB-2).
|
||||
|
||||
Migrations 0001 and 0002 are already shipped, so we only exercise
|
||||
upgrade/downgrade for 0003+.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
|
||||
# Oldest shipped revision — we upgrade to here as a baseline and never
|
||||
# downgrade past it.
|
||||
_BASELINE = "0002_merge_to_asset_references"
|
||||
|
||||
|
||||
def _make_config(db_path: str) -> Config:
|
||||
root = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
|
||||
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_db(tmp_path):
|
||||
"""Yield an alembic Config pre-upgraded to the baseline revision."""
|
||||
db_path = str(tmp_path / "test_migration.db")
|
||||
cfg = _make_config(db_path)
|
||||
command.upgrade(cfg, _BASELINE)
|
||||
yield cfg
|
||||
|
||||
|
||||
def test_upgrade_to_head(migration_db):
|
||||
"""Upgrade from baseline to head must succeed on a file-backed DB."""
|
||||
command.upgrade(migration_db, "head")
|
||||
|
||||
|
||||
def test_downgrade_to_baseline(migration_db):
|
||||
"""Upgrade to head then downgrade back to baseline."""
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
|
||||
|
||||
def test_upgrade_downgrade_cycle(migration_db):
|
||||
"""Full cycle: upgrade → downgrade → upgrade again."""
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
command.upgrade(migration_db, "head")
|
||||
@@ -10,6 +10,7 @@ from app.assets.database.queries import (
|
||||
get_asset_by_hash,
|
||||
upsert_asset,
|
||||
bulk_insert_assets,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
|
||||
|
||||
@@ -142,3 +143,45 @@ class TestBulkInsertAssets:
|
||||
session.commit()
|
||||
|
||||
assert session.query(Asset).count() == 200
|
||||
|
||||
|
||||
class TestMimeTypeImmutability:
|
||||
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mime,second_mime,expected_mime",
|
||||
[
|
||||
("image/png", "image/jpeg", "image/png"),
|
||||
(None, "image/png", "image/png"),
|
||||
],
|
||||
ids=["preserves_existing", "fills_null"],
|
||||
)
|
||||
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
|
||||
h = f"blake3:upsert_{initial_mime}_{second_mime}"
|
||||
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
|
||||
session.commit()
|
||||
|
||||
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
|
||||
assert created is False
|
||||
assert asset.mime_type == expected_mime
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
|
||||
[
|
||||
(None, "image/png", None, "image/png", "blake3:upd0"),
|
||||
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
|
||||
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
|
||||
],
|
||||
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
|
||||
)
|
||||
def test_update_asset_hash_and_mime_immutability(
|
||||
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
|
||||
):
|
||||
h = expected_hash.removesuffix("_new")
|
||||
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
|
||||
assert asset.mime_type == expected_mime
|
||||
assert asset.hash == expected_hash
|
||||
|
||||
@@ -242,22 +242,24 @@ class TestSetReferencePreview:
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
session.commit()
|
||||
|
||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id == preview_asset.id
|
||||
assert ref.preview_id == preview_ref.id
|
||||
|
||||
def test_clears_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
ref.preview_id = preview_asset.id
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref.preview_id = preview_ref.id
|
||||
session.commit()
|
||||
|
||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
@@ -265,15 +267,15 @@ class TestSetReferencePreview:
|
||||
|
||||
def test_raises_for_nonexistent_reference(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)
|
||||
set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
|
||||
|
||||
def test_raises_for_nonexistent_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(ValueError, match="Preview Asset"):
|
||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")
|
||||
with pytest.raises(ValueError, match="Preview AssetReference"):
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
|
||||
|
||||
|
||||
class TestInsertReference:
|
||||
@@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
session.commit()
|
||||
|
||||
update_reference_timestamps(session, ref, preview_id=preview_asset.id)
|
||||
update_reference_timestamps(session, ref, preview_id=preview_ref.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id == preview_asset.id
|
||||
assert ref.preview_id == preview_ref.id
|
||||
|
||||
|
||||
class TestSetReferenceMetadata:
|
||||
|
||||
@@ -20,6 +20,7 @@ def _make_reference(
|
||||
asset: Asset,
|
||||
name: str,
|
||||
metadata: dict | None = None,
|
||||
system_metadata: dict | None = None,
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
@@ -27,6 +28,7 @@ def _make_reference(
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
user_metadata=metadata,
|
||||
system_metadata=system_metadata,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
@@ -34,8 +36,10 @@ def _make_reference(
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
|
||||
if metadata:
|
||||
for key, val in metadata.items():
|
||||
# Build merged projection: {**system_metadata, **user_metadata}
|
||||
merged = {**(system_metadata or {}), **(metadata or {})}
|
||||
if merged:
|
||||
for key, val in merged.items():
|
||||
for row in convert_metadata_to_rows(key, val):
|
||||
meta_row = AssetReferenceMeta(
|
||||
asset_reference_id=ref.id,
|
||||
@@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
|
||||
|
||||
refs, _, total = list_references_page(session, metadata_filter={})
|
||||
assert total == 2
|
||||
|
||||
|
||||
class TestSystemMetadataProjection:
|
||||
"""Tests for system_metadata merging into the filter projection."""
|
||||
|
||||
def test_system_metadata_keys_are_filterable(self, session: Session):
|
||||
"""system_metadata keys should appear in the merged projection."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(
|
||||
session, asset, "with_sys",
|
||||
system_metadata={"source": "scanner"},
|
||||
)
|
||||
_make_reference(session, asset, "without_sys")
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"source": "scanner"}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "with_sys"
|
||||
|
||||
def test_user_metadata_overrides_system_metadata(self, session: Session):
|
||||
"""user_metadata should win when both have the same key."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(
|
||||
session, asset, "overridden",
|
||||
metadata={"origin": "user_upload"},
|
||||
system_metadata={"origin": "auto_scan"},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Should match the user value, not the system value
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"origin": "user_upload"}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "overridden"
|
||||
|
||||
# Should NOT match the system value (it was overridden)
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"origin": "auto_scan"}
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Base
|
||||
@@ -23,6 +23,21 @@ def db_engine():
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine_fk():
|
||||
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def _set_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(db_engine):
|
||||
"""Session fixture for tests that need direct DB access."""
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.assets.services import (
|
||||
delete_asset_reference,
|
||||
set_asset_preview,
|
||||
)
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
|
||||
@@ -219,31 +220,33 @@ class TestSetAssetPreview:
|
||||
asset = _make_asset(session, hash_val="blake3:main")
|
||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref_id = ref.id
|
||||
preview_id = preview_asset.id
|
||||
preview_ref_id = preview_ref.id
|
||||
session.commit()
|
||||
|
||||
set_asset_preview(
|
||||
reference_id=ref_id,
|
||||
preview_asset_id=preview_id,
|
||||
preview_reference_id=preview_ref_id,
|
||||
)
|
||||
|
||||
# Verify by re-fetching from DB
|
||||
session.expire_all()
|
||||
updated_ref = session.get(AssetReference, ref_id)
|
||||
assert updated_ref.preview_id == preview_id
|
||||
assert updated_ref.preview_id == preview_ref_id
|
||||
|
||||
def test_clears_preview(self, mock_create_session, session: Session):
|
||||
asset = _make_asset(session)
|
||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||
ref = _make_reference(session, asset)
|
||||
ref.preview_id = preview_asset.id
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref.preview_id = preview_ref.id
|
||||
ref_id = ref.id
|
||||
session.commit()
|
||||
|
||||
set_asset_preview(
|
||||
reference_id=ref_id,
|
||||
preview_asset_id=None,
|
||||
preview_reference_id=None,
|
||||
)
|
||||
|
||||
# Verify by re-fetching from DB
|
||||
@@ -263,6 +266,45 @@ class TestSetAssetPreview:
|
||||
with pytest.raises(PermissionError, match="not owner"):
|
||||
set_asset_preview(
|
||||
reference_id=ref.id,
|
||||
preview_asset_id=None,
|
||||
preview_reference_id=None,
|
||||
owner_id="user2",
|
||||
)
|
||||
|
||||
|
||||
class TestResolveHashToPath:
|
||||
def test_returns_none_for_unknown_hash(self, mock_create_session):
|
||||
result = resolve_hash_to_path("blake3:" + "a" * 64)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ref_owner, query_owner, expect_found",
|
||||
[
|
||||
("user1", "user1", True),
|
||||
("user1", "user2", False),
|
||||
("", "anyone", True),
|
||||
("", "", True),
|
||||
],
|
||||
ids=[
|
||||
"owner_sees_own_ref",
|
||||
"other_owner_blocked",
|
||||
"ownerless_visible_to_anyone",
|
||||
"ownerless_visible_to_empty",
|
||||
],
|
||||
)
|
||||
def test_owner_visibility(
|
||||
self, ref_owner, query_owner, expect_found,
|
||||
mock_create_session, session: Session, temp_dir,
|
||||
):
|
||||
f = temp_dir / "file.bin"
|
||||
f.write_bytes(b"data")
|
||||
asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
|
||||
ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
|
||||
ref.file_path = str(f)
|
||||
session.commit()
|
||||
|
||||
result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
|
||||
if expect_found:
|
||||
assert result is not None
|
||||
assert result.abs_path == str(f)
|
||||
else:
|
||||
assert result is None
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.services.file_utils import get_mtime_ns
|
||||
from app.assets.scanner import (
|
||||
ENRICHMENT_HASHED,
|
||||
ENRICHMENT_METADATA,
|
||||
@@ -20,6 +22,13 @@ def _create_stub_asset(
|
||||
name: str | None = None,
|
||||
) -> tuple[Asset, AssetReference]:
|
||||
"""Create a stub asset with reference for testing enrichment."""
|
||||
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||
try:
|
||||
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||
mtime_ns = get_mtime_ns(stat_result)
|
||||
except OSError:
|
||||
mtime_ns = 1234567890000000000
|
||||
|
||||
asset = Asset(
|
||||
id=asset_id,
|
||||
hash=None,
|
||||
@@ -35,7 +44,7 @@ def _create_stub_asset(
|
||||
name=name or f"test-asset-{asset_id}",
|
||||
owner_id="system",
|
||||
file_path=file_path,
|
||||
mtime_ns=1234567890000000000,
|
||||
mtime_ns=mtime_ns,
|
||||
enrichment_level=ENRICHMENT_STUB,
|
||||
)
|
||||
session.add(ref)
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""Tests for ingest services."""
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session as SASession, Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, Tag
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||
from app.assets.database.queries import get_reference_tags
|
||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
||||
from app.assets.services.ingest import (
|
||||
_ingest_file_from_path,
|
||||
_register_existing_asset,
|
||||
ingest_existing_file,
|
||||
)
|
||||
|
||||
|
||||
class TestIngestFileFromPath:
|
||||
@@ -113,11 +119,19 @@ class TestIngestFileFromPath:
|
||||
file_path = temp_dir / "with_preview.bin"
|
||||
file_path.write_bytes(b"data")
|
||||
|
||||
# Create a preview asset first
|
||||
# Create a preview asset and reference
|
||||
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
|
||||
session.add(preview_asset)
|
||||
session.flush()
|
||||
from app.assets.helpers import get_utc_now
|
||||
now = get_utc_now()
|
||||
preview_ref = AssetReference(
|
||||
asset_id=preview_asset.id, name="preview.png", owner_id="",
|
||||
created_at=now, updated_at=now, last_access_time=now,
|
||||
)
|
||||
session.add(preview_ref)
|
||||
session.commit()
|
||||
preview_id = preview_asset.id
|
||||
preview_id = preview_ref.id
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(file_path),
|
||||
@@ -227,3 +241,42 @@ class TestRegisterExistingAsset:
|
||||
|
||||
assert result.created is True
|
||||
assert set(result.tags) == {"alpha", "beta"}
|
||||
|
||||
|
||||
class TestIngestExistingFileTagFK:
|
||||
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||
|
||||
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||
before they can be referenced in asset_reference_tags."""
|
||||
|
||||
@contextmanager
|
||||
def _create_session():
|
||||
with SASession(db_engine_fk) as sess:
|
||||
yield sess
|
||||
|
||||
file_path = temp_dir / "output.png"
|
||||
file_path.write_bytes(b"image data")
|
||||
|
||||
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||
patch(
|
||||
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||
return_value=("output.png", ["output"]),
|
||||
):
|
||||
result = ingest_existing_file(
|
||||
abs_path=str(file_path),
|
||||
extra_tags=["my-job"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
with SASession(db_engine_fk) as sess:
|
||||
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||
assert "output" in tag_names
|
||||
assert "my-job" in tag_names
|
||||
|
||||
ref_tags = sess.query(AssetReferenceTag).all()
|
||||
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||
assert "output" in ref_tag_names
|
||||
assert "my-job" in ref_tag_names
|
||||
|
||||
123
tests-unit/assets_test/services/test_tag_histogram.py
Normal file
123
tests-unit/assets_test/services/test_tag_histogram.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests for list_tag_histogram service function."""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
|
||||
asset = Asset(hash=hash_val, size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def _make_reference(
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
name: str = "test",
|
||||
owner_id: str = "",
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
|
||||
class TestListTagHistogram:
|
||||
def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="r2")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram()
|
||||
|
||||
assert result["alpha"] == 2
|
||||
assert result["beta"] == 1
|
||||
|
||||
def test_empty_when_no_assets(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["unused"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram()
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_include_tags_filter(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="r2")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(include_tags=["models"])
|
||||
|
||||
# Only r1 has "models", so only its tags appear
|
||||
assert "models" in result
|
||||
assert "loras" in result
|
||||
assert "input" not in result
|
||||
|
||||
def test_exclude_tags_filter(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="r2")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(exclude_tags=["models"])
|
||||
|
||||
# r1 excluded, only r2's tags remain
|
||||
assert "input" in result
|
||||
assert "loras" not in result
|
||||
|
||||
def test_name_contains_filter(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="my_model.safetensors")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="picture.png")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(name_contains="model")
|
||||
|
||||
assert "alpha" in result
|
||||
assert "beta" not in result
|
||||
|
||||
def test_limit_caps_results(self, mock_create_session, session: Session):
|
||||
tags = [f"tag{i}" for i in range(10)]
|
||||
ensure_tags_exist(session, tags)
|
||||
a = _make_asset(session, "blake3:aaa")
|
||||
r = _make_reference(session, a, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r.id, tags=tags)
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(limit=3)
|
||||
|
||||
assert len(result) == 3
|
||||
@@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
|
||||
|
||||
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||
root: str,
|
||||
|
||||
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_nodes = MagicMock()
|
||||
mock_nodes.MAX_RESOLUTION = 16384
|
||||
mock_server = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||
|
||||
|
||||
class TestNumberConvertExecute:
|
||||
@staticmethod
|
||||
def _exec(value) -> object:
|
||||
return NumberConvertNode.execute(value)
|
||||
|
||||
# --- INT input ---
|
||||
|
||||
def test_int_input(self):
|
||||
result = self._exec(42)
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_int_zero(self):
|
||||
result = self._exec(0)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
def test_int_negative(self):
|
||||
result = self._exec(-7)
|
||||
assert result[0] == -7.0
|
||||
assert result[1] == -7
|
||||
|
||||
# --- FLOAT input ---
|
||||
|
||||
def test_float_input(self):
|
||||
result = self._exec(3.14)
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_float_truncation_toward_zero(self):
|
||||
result = self._exec(-2.9)
|
||||
assert result[0] == -2.9
|
||||
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||
|
||||
def test_float_output_type(self):
|
||||
result = self._exec(5)
|
||||
assert isinstance(result[0], float)
|
||||
|
||||
def test_int_output_type(self):
|
||||
result = self._exec(5.7)
|
||||
assert isinstance(result[1], int)
|
||||
|
||||
# --- BOOL input ---
|
||||
|
||||
def test_bool_true(self):
|
||||
result = self._exec(True)
|
||||
assert result[0] == 1.0
|
||||
assert result[1] == 1
|
||||
|
||||
def test_bool_false(self):
|
||||
result = self._exec(False)
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0
|
||||
|
||||
# --- STRING input ---
|
||||
|
||||
def test_string_integer(self):
|
||||
result = self._exec("42")
|
||||
assert result[0] == 42.0
|
||||
assert result[1] == 42
|
||||
|
||||
def test_string_float(self):
|
||||
result = self._exec("3.14")
|
||||
assert result[0] == 3.14
|
||||
assert result[1] == 3
|
||||
|
||||
def test_string_negative(self):
|
||||
result = self._exec("-5.5")
|
||||
assert result[0] == -5.5
|
||||
assert result[1] == -5
|
||||
|
||||
def test_string_with_whitespace(self):
|
||||
result = self._exec(" 7.0 ")
|
||||
assert result[0] == 7.0
|
||||
assert result[1] == 7
|
||||
|
||||
def test_string_scientific_notation(self):
|
||||
result = self._exec("1e3")
|
||||
assert result[0] == 1000.0
|
||||
assert result[1] == 1000
|
||||
|
||||
# --- STRING error paths ---
|
||||
|
||||
def test_empty_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec("")
|
||||
|
||||
def test_whitespace_only_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||
self._exec(" ")
|
||||
|
||||
def test_non_numeric_string_raises(self):
|
||||
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||
self._exec("abc")
|
||||
|
||||
def test_string_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("inf")
|
||||
|
||||
def test_string_nan_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("nan")
|
||||
|
||||
def test_string_negative_inf_raises(self):
|
||||
with pytest.raises(ValueError, match="non-finite"):
|
||||
self._exec("-inf")
|
||||
|
||||
# --- Unsupported type ---
|
||||
|
||||
def test_unsupported_type_raises(self):
|
||||
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||
self._exec([1, 2, 3])
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -771,6 +772,188 @@ class TestSeederStopRestart:
|
||||
assert collected_roots[1] == ("input",)
|
||||
|
||||
|
||||
class TestEnqueueEnrichHandoff:
|
||||
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||
|
||||
def test_pending_enrich_runs_after_scan_completes(
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
"""A queued enrich request runs automatically when a scan finishes."""
|
||||
enrich_roots_seen: list[tuple] = []
|
||||
original_start = fresh_seeder.start
|
||||
|
||||
def tracking_start(*args, **kwargs):
|
||||
phase = kwargs.get("phase")
|
||||
roots = kwargs.get("roots", args[0] if args else None)
|
||||
result = original_start(*args, **kwargs)
|
||||
if phase == ScanPhase.ENRICH and result:
|
||||
enrich_roots_seen.append(roots)
|
||||
return result
|
||||
|
||||
fresh_seeder.start = tracking_start
|
||||
|
||||
# Start a fast scan, then enqueue an enrich while it's running
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
|
||||
def slow_collect(*args):
|
||||
reached.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert reached.wait(timeout=2.0)
|
||||
|
||||
queued = fresh_seeder.enqueue_enrich(
|
||||
roots=("input",), compute_hashes=True
|
||||
)
|
||||
assert queued is False # queued, not started immediately
|
||||
|
||||
barrier.set()
|
||||
|
||||
# Wait for the original scan + the auto-started enrich scan
|
||||
deadline = time.monotonic() + 5.0
|
||||
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||
time.sleep(0.05)
|
||||
|
||||
assert enrich_roots_seen == [("input",)]
|
||||
|
||||
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||
|
||||
Simulates the race: another thread calls enqueue_enrich right as the
|
||||
scan thread is draining _pending_enrich. The enqueue must either be
|
||||
picked up by the draining scan or successfully start its own scan.
|
||||
"""
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
enrich_started = threading.Event()
|
||||
|
||||
enrich_call_count = 0
|
||||
|
||||
def slow_collect(*args):
|
||||
reached.set()
|
||||
barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
# Track how many times start_enrich actually fires
|
||||
real_start_enrich = fresh_seeder.start_enrich
|
||||
enrich_roots_seen: list[tuple] = []
|
||||
|
||||
def tracking_start_enrich(**kwargs):
|
||||
nonlocal enrich_call_count
|
||||
enrich_call_count += 1
|
||||
enrich_roots_seen.append(kwargs.get("roots"))
|
||||
result = real_start_enrich(**kwargs)
|
||||
if result:
|
||||
enrich_started.set()
|
||||
return result
|
||||
|
||||
fresh_seeder.start_enrich = tracking_start_enrich
|
||||
|
||||
with patch(
|
||||
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||
):
|
||||
# Start a scan
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert reached.wait(timeout=2.0)
|
||||
|
||||
# Queue an enrich while scan is running
|
||||
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
# Let scan finish — drain will fire start_enrich atomically
|
||||
barrier.set()
|
||||
|
||||
# Wait for drain to complete and the enrich scan to start
|
||||
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||
assert ("output",) in enrich_roots_seen
|
||||
|
||||
def test_concurrent_enqueue_during_drain_not_lost(
|
||||
self, fresh_seeder: _AssetSeeder,
|
||||
):
|
||||
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||
|
||||
Because the drain now holds _lock through the start_enrich call,
|
||||
a concurrent enqueue_enrich will block until start_enrich has
|
||||
transitioned state to RUNNING, then the enqueue will queue its
|
||||
payload as _pending_enrich for the *next* drain.
|
||||
"""
|
||||
scan_barrier = threading.Event()
|
||||
scan_reached = threading.Event()
|
||||
enrich_barrier = threading.Event()
|
||||
enrich_reached = threading.Event()
|
||||
|
||||
collect_call = 0
|
||||
|
||||
def gated_collect(*args):
|
||||
nonlocal collect_call
|
||||
collect_call += 1
|
||||
if collect_call == 1:
|
||||
# First call: the initial fast scan
|
||||
scan_reached.set()
|
||||
scan_barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
enrich_call = 0
|
||||
|
||||
def gated_get_unenriched(*args, **kwargs):
|
||||
nonlocal enrich_call
|
||||
enrich_call += 1
|
||||
if enrich_call == 1:
|
||||
# First enrich batch: signal and block
|
||||
enrich_reached.set()
|
||||
enrich_barrier.wait(timeout=5.0)
|
||||
return []
|
||||
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
):
|
||||
# 1. Start fast scan
|
||||
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||
assert scan_reached.wait(timeout=2.0)
|
||||
|
||||
# 2. Queue enrich while fast scan is running
|
||||
queued = fresh_seeder.enqueue_enrich(
|
||||
roots=("input",), compute_hashes=False
|
||||
)
|
||||
assert queued is False
|
||||
|
||||
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||
scan_barrier.set()
|
||||
|
||||
# 4. Wait until the drained enrich scan is running
|
||||
assert enrich_reached.wait(timeout=5.0)
|
||||
|
||||
# 5. Now enqueue another enrich while the drained scan is running
|
||||
queued2 = fresh_seeder.enqueue_enrich(
|
||||
roots=("output",), compute_hashes=True
|
||||
)
|
||||
assert queued2 is False # should be queued, not started
|
||||
|
||||
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||
with fresh_seeder._lock:
|
||||
assert fresh_seeder._pending_enrich is not None
|
||||
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||
|
||||
# Let the enrich scan finish
|
||||
enrich_barrier.set()
|
||||
|
||||
deadline = time.monotonic() + 5.0
|
||||
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||
time.sleep(0.05)
|
||||
|
||||
|
||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||
return UnenrichedReferenceRow(
|
||||
reference_id=ref_id, asset_id=asset_id,
|
||||
|
||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def seeder():
|
||||
"""Fresh seeder instance for each test."""
|
||||
return _AssetSeeder()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reset_to_idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResetToIdle:
|
||||
def test_sets_idle_and_clears_progress(self, seeder):
|
||||
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||
seeder._state = State.RUNNING
|
||||
seeder._progress = progress
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is progress
|
||||
|
||||
def test_noop_when_progress_already_none(self, seeder):
|
||||
"""_reset_to_idle should handle None progress gracefully."""
|
||||
seeder._state = State.CANCELLING
|
||||
seeder._progress = None
|
||||
|
||||
with seeder._lock:
|
||||
seeder._reset_to_idle()
|
||||
|
||||
assert seeder._state is State.IDLE
|
||||
assert seeder._progress is None
|
||||
assert seeder._last_progress is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – immediate start when idle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichStartsImmediately:
|
||||
def test_starts_when_idle(self, seeder):
|
||||
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||
|
||||
def test_no_pending_when_started_immediately(self, seeder):
|
||||
"""No pending request should be stored when start_enrich succeeds."""
|
||||
with patch.object(seeder, "start_enrich", return_value=True):
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – queuing when busy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichQueuesWhenBusy:
|
||||
def test_queues_when_busy(self, seeder):
|
||||
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
|
||||
assert result is False
|
||||
assert seeder._pending_enrich == {
|
||||
"roots": ("models",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# enqueue_enrich – merging when a pending request already exists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichMergesPending:
|
||||
def _make_busy(self, seeder):
|
||||
"""Patch start_enrich to always return False (seeder busy)."""
|
||||
return patch.object(seeder, "start_enrich", return_value=False)
|
||||
|
||||
def test_merges_roots(self, seeder):
|
||||
"""A second enqueue should merge roots with the existing pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",))
|
||||
seeder.enqueue_enrich(roots=("output",))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "output"}
|
||||
|
||||
def test_merges_overlapping_roots(self, seeder):
|
||||
"""Duplicate roots should be deduplicated."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models", "input"))
|
||||
seeder.enqueue_enrich(roots=("input", "output"))
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
|
||||
def test_compute_hashes_sticky_true(self, seeder):
|
||||
"""Once compute_hashes is True it should stay True after merging."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
def test_compute_hashes_stays_false(self, seeder):
|
||||
"""If both enqueues have compute_hashes=False it stays False."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||
|
||||
assert seeder._pending_enrich["compute_hashes"] is False
|
||||
|
||||
def test_triple_merge(self, seeder):
|
||||
"""Three successive enqueues should all merge correctly."""
|
||||
with self._make_busy(seeder):
|
||||
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
assert seeder._pending_enrich["compute_hashes"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending enrich drains after scan completes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPendingEnrichDrain:
|
||||
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||
"""After a fast scan finishes, the pending enrich should be started."""
|
||||
seeder = _AssetSeeder()
|
||||
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": True,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_called_once_with(
|
||||
roots=("output",),
|
||||
compute_hashes=True,
|
||||
)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||
seeder = _AssetSeeder()
|
||||
seeder._pending_enrich = {
|
||||
"roots": ("output",),
|
||||
"compute_hashes": False,
|
||||
}
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||
def test_no_drain_when_no_pending(self, *_mocks):
|
||||
"""start_enrich should not be called when there is no pending request."""
|
||||
seeder = _AssetSeeder()
|
||||
assert seeder._pending_enrich is None
|
||||
|
||||
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||
seeder.start_fast(roots=("models",))
|
||||
seeder.wait(timeout=5)
|
||||
|
||||
mock_start.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread-safety of enqueue_enrich
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnqueueEnrichThreadSafety:
|
||||
def test_concurrent_enqueues(self, seeder):
|
||||
"""Multiple threads enqueuing should not lose roots."""
|
||||
with patch.object(seeder, "start_enrich", return_value=False):
|
||||
barrier = threading.Barrier(3)
|
||||
|
||||
def enqueue(root):
|
||||
barrier.wait()
|
||||
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=enqueue, args=(r,))
|
||||
for r in ("models", "input", "output")
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
merged = set(seeder._pending_enrich["roots"])
|
||||
assert merged == {"models", "input", "output"}
|
||||
Reference in New Issue
Block a user