mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-18 03:27:30 +00:00
Compare commits
2 Commits
curve-node
...
fix/static
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
286e3e8ed4 | ||
|
|
a9ce45279e |
103
.github/scripts/check-ai-co-authors.sh
vendored
103
.github/scripts/check-ai-co-authors.sh
vendored
@@ -1,103 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Checks pull request commits for AI agent Co-authored-by trailers.
|
||||
# Exits non-zero when any are found and prints fix instructions.
|
||||
set -euo pipefail
|
||||
|
||||
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
|
||||
# Known AI coding-agent trailer patterns (case-insensitive).
|
||||
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
|
||||
AGENT_PATTERNS=(
|
||||
# Anthropic — Claude Code / Amp
|
||||
'noreply@anthropic\.com'
|
||||
# Cursor
|
||||
'cursoragent@cursor\.com'
|
||||
# GitHub Copilot
|
||||
'copilot-swe-agent\[bot\]'
|
||||
'copilot@github\.com'
|
||||
# OpenAI Codex
|
||||
'noreply@openai\.com'
|
||||
'codex@openai\.com'
|
||||
# Aider
|
||||
'aider@aider\.chat'
|
||||
# Google — Gemini / Jules
|
||||
'gemini@google\.com'
|
||||
'jules@google\.com'
|
||||
# Windsurf / Codeium
|
||||
'@codeium\.com'
|
||||
# Devin
|
||||
'devin-ai-integration\[bot\]'
|
||||
'devin@cognition\.ai'
|
||||
'devin@cognition-labs\.com'
|
||||
# Amazon Q Developer
|
||||
'amazon-q-developer'
|
||||
'@amazon\.com.*[Qq].[Dd]eveloper'
|
||||
# Cline
|
||||
'cline-bot'
|
||||
'cline@cline\.ai'
|
||||
# Continue
|
||||
'continue-agent'
|
||||
'continue@continue\.dev'
|
||||
# Sourcegraph
|
||||
'noreply@sourcegraph\.com'
|
||||
# Generic catch-alls for common agent name patterns
|
||||
'Co-authored-by:.*\b[Cc]laude\b'
|
||||
'Co-authored-by:.*\b[Cc]opilot\b'
|
||||
'Co-authored-by:.*\b[Cc]ursor\b'
|
||||
'Co-authored-by:.*\b[Cc]odex\b'
|
||||
'Co-authored-by:.*\b[Gg]emini\b'
|
||||
'Co-authored-by:.*\b[Aa]ider\b'
|
||||
'Co-authored-by:.*\b[Dd]evin\b'
|
||||
'Co-authored-by:.*\b[Ww]indsurf\b'
|
||||
'Co-authored-by:.*\b[Cc]line\b'
|
||||
'Co-authored-by:.*\b[Aa]mazon Q\b'
|
||||
'Co-authored-by:.*\b[Jj]ules\b'
|
||||
'Co-authored-by:.*\bOpenCode\b'
|
||||
)
|
||||
|
||||
# Build a single alternation regex from all patterns.
|
||||
regex=""
|
||||
for pattern in "${AGENT_PATTERNS[@]}"; do
|
||||
if [[ -n "$regex" ]]; then
|
||||
regex="${regex}|${pattern}"
|
||||
else
|
||||
regex="$pattern"
|
||||
fi
|
||||
done
|
||||
|
||||
# Collect Co-authored-by lines from every commit in the PR range.
|
||||
violations=""
|
||||
while IFS= read -r sha; do
|
||||
message="$(git log -1 --format='%B' "$sha")"
|
||||
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
|
||||
if [[ -z "$matched_lines" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
while IFS= read -r line; do
|
||||
if echo "$line" | grep -iqE "$regex"; then
|
||||
short="$(git log -1 --format='%h' "$sha")"
|
||||
violations="${violations} ${short}: ${line}"$'\n'
|
||||
fi
|
||||
done <<< "$matched_lines"
|
||||
done < <(git rev-list "${base_sha}..${head_sha}")
|
||||
|
||||
if [[ -n "$violations" ]]; then
|
||||
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
|
||||
echo ""
|
||||
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
|
||||
echo ""
|
||||
echo "$violations"
|
||||
echo "These trailers should be removed before merging."
|
||||
echo ""
|
||||
echo "To fix, rewrite the commit messages with:"
|
||||
echo " git rebase -i ${base_sha}"
|
||||
echo ""
|
||||
echo "and remove the Co-authored-by lines, then force-push your branch."
|
||||
echo ""
|
||||
echo "If you believe this is a false positive, please open an issue."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No AI agent Co-authored-by trailers found."
|
||||
19
.github/workflows/check-ai-co-authors.yml
vendored
19
.github/workflows/check-ai-co-authors.yml
vendored
@@ -1,19 +0,0 @@
|
||||
name: Check AI Co-Authors
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*']
|
||||
|
||||
jobs:
|
||||
check-ai-co-authors:
|
||||
name: Check for AI agent co-author trailers
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check commits for AI co-author trailers
|
||||
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
|
||||
11
README.md
11
README.md
@@ -38,8 +38,6 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
|
||||
## Get Started
|
||||
|
||||
### Local
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
@@ -51,13 +49,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
#### [Manual Install](#manual-install-windows-linux)
|
||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||
|
||||
### Cloud
|
||||
|
||||
#### [Comfy Cloud](https://www.comfy.org/cloud)
|
||||
- Our official paid cloud version for those who can't afford local hardware.
|
||||
|
||||
## Examples
|
||||
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
|
||||
@@ -8,7 +8,7 @@ from alembic import context
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base, NAMING_CONVENTION
|
||||
from app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
@@ -51,10 +51,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
Add system_metadata and job_id columns to asset_references.
|
||||
Change preview_id FK from assets.id to asset_references.id.
|
||||
|
||||
Revision ID: 0003_add_metadata_job_id
|
||||
Revises: 0002_merge_to_asset_references
|
||||
Create Date: 2026-03-09
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.database.models import NAMING_CONVENTION
|
||||
|
||||
revision = "0003_add_metadata_job_id"
|
||||
down_revision = "0002_merge_to_asset_references"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("system_metadata", sa.JSON(), nullable=True)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("job_id", sa.String(length=36), nullable=True)
|
||||
)
|
||||
|
||||
# Change preview_id FK from assets.id to asset_references.id (self-ref).
|
||||
# Existing values are asset-content IDs that won't match reference IDs,
|
||||
# so null them out first.
|
||||
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_assets", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_asset_references",
|
||||
"asset_references",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
batch_op.create_index(
|
||||
"ix_asset_references_preview_id", ["preview_id"]
|
||||
)
|
||||
|
||||
# Purge any all-null meta rows before adding the constraint
|
||||
op.execute(
|
||||
"DELETE FROM asset_reference_meta"
|
||||
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
|
||||
)
|
||||
with op.batch_alter_table("asset_reference_meta") as batch_op:
|
||||
batch_op.create_check_constraint(
|
||||
"ck_asset_reference_meta_has_value",
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# SQLite doesn't reflect CHECK constraints, so we must declare it
|
||||
# explicitly via table_args for the batch recreate to find it.
|
||||
# Use the fully-rendered constraint name to avoid the naming convention
|
||||
# doubling the prefix.
|
||||
with op.batch_alter_table(
|
||||
"asset_reference_meta",
|
||||
table_args=[
|
||||
sa.CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="ck_asset_reference_meta_has_value",
|
||||
),
|
||||
],
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"ck_asset_reference_meta_has_value", type_="check"
|
||||
)
|
||||
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_index("ix_asset_references_preview_id")
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_assets",
|
||||
"assets",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.drop_column("job_id")
|
||||
batch_op.drop_column("system_metadata")
|
||||
@@ -13,7 +13,6 @@ from pydantic import ValidationError
|
||||
import folder_paths
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in, schemas_out
|
||||
from app.assets.services import schemas
|
||||
from app.assets.api.schemas_in import (
|
||||
AssetValidationError,
|
||||
UploadError,
|
||||
@@ -39,7 +38,6 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@@ -124,61 +122,6 @@ def _validate_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
|
||||
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
|
||||
if not user_metadata:
|
||||
return None
|
||||
filename = user_metadata.get("filename")
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
if "input" in tags:
|
||||
view_type = "input"
|
||||
elif "output" in tags:
|
||||
view_type = "output"
|
||||
else:
|
||||
return None
|
||||
|
||||
subfolder = ""
|
||||
if "/" in filename:
|
||||
subfolder, filename = filename.rsplit("/", 1)
|
||||
|
||||
encoded_filename = urllib.parse.quote(filename, safe="")
|
||||
url = f"/api/view?type={view_type}&filename={encoded_filename}"
|
||||
if subfolder:
|
||||
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
|
||||
return url
|
||||
|
||||
|
||||
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
|
||||
"""Build an Asset response from a service result."""
|
||||
if result.ref.preview_id:
|
||||
preview_detail = get_asset_detail(result.ref.preview_id)
|
||||
if preview_detail:
|
||||
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
|
||||
else:
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
preview_url=preview_url,
|
||||
preview_id=result.ref.preview_id,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
@_require_assets_feature_enabled
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
@@ -221,7 +164,20 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
summaries = [
|
||||
schemas_out.AssetSummary(
|
||||
id=item.ref.id,
|
||||
name=item.ref.name,
|
||||
asset_hash=item.asset.hash if item.asset else None,
|
||||
size=int(item.asset.size_bytes) if item.asset else None,
|
||||
mime_type=item.asset.mime_type if item.asset else None,
|
||||
tags=item.tags,
|
||||
created_at=item.ref.created_at,
|
||||
updated_at=item.ref.updated_at,
|
||||
last_access_time=item.ref.last_access_time,
|
||||
)
|
||||
for item in result.items
|
||||
]
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
@@ -251,7 +207,18 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
{"id": reference_id},
|
||||
)
|
||||
|
||||
payload = _build_asset_response(result)
|
||||
payload = schemas_out.AssetDetail(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
)
|
||||
except ValueError as e:
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
||||
@@ -263,7 +230,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
@@ -345,31 +312,32 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||
)
|
||||
|
||||
# Derive name from hash if not provided
|
||||
name = body.name
|
||||
if name is None:
|
||||
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
|
||||
|
||||
result = create_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=name,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
mime_type=body.mime_type,
|
||||
preview_id=body.preview_id,
|
||||
)
|
||||
if result is None:
|
||||
return _build_error_response(
|
||||
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
||||
)
|
||||
|
||||
asset = _build_asset_response(result)
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
**asset.model_dump(),
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
created_new=result.created_new,
|
||||
)
|
||||
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
|
||||
return web.json_response(payload_out.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
@@ -390,8 +358,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
"name": parsed.provided_name,
|
||||
"user_metadata": parsed.user_metadata_raw,
|
||||
"hash": parsed.provided_hash,
|
||||
"mime_type": parsed.provided_mime_type,
|
||||
"preview_id": parsed.provided_preview_id,
|
||||
}
|
||||
)
|
||||
except ValidationError as ve:
|
||||
@@ -420,8 +386,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
mime_type=spec.mime_type,
|
||||
preview_id=spec.preview_id,
|
||||
)
|
||||
if result is None:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
@@ -446,8 +410,6 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
client_filename=parsed.file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_hash=spec.hash,
|
||||
mime_type=spec.mime_type,
|
||||
preview_id=spec.preview_id,
|
||||
)
|
||||
except AssetValidationError as e:
|
||||
delete_temp_file_if_exists(parsed.tmp_path)
|
||||
@@ -466,13 +428,21 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
asset = _build_asset_response(result)
|
||||
payload_out = schemas_out.AssetCreated(
|
||||
**asset.model_dump(),
|
||||
payload = schemas_out.AssetCreated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash,
|
||||
size=int(result.asset.size_bytes),
|
||||
mime_type=result.asset.mime_type,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
preview_id=result.ref.preview_id,
|
||||
created_at=result.ref.created_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
created_new=result.created_new,
|
||||
)
|
||||
status = 201 if result.created_new else 200
|
||||
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=status)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -494,9 +464,15 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
preview_id=body.preview_id,
|
||||
)
|
||||
payload = _build_asset_response(result)
|
||||
payload = schemas_out.AssetUpdated(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
tags=result.tags,
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
updated_at=result.ref.updated_at,
|
||||
)
|
||||
except PermissionError as pe:
|
||||
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
||||
except ValueError as ve:
|
||||
@@ -510,7 +486,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@@ -579,7 +555,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@@ -627,7 +603,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@@ -674,29 +650,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/tags/refine")
|
||||
@_require_assets_feature_enabled
|
||||
async def get_tags_refine(request: web.Request) -> web.Response:
|
||||
"""GET request to get tag histogram for filtered assets."""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _build_validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
tag_counts = list_tag_histogram(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
)
|
||||
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
|
||||
@@ -45,8 +45,6 @@ class ParsedUpload:
|
||||
user_metadata_raw: str | None
|
||||
provided_hash: str | None
|
||||
provided_hash_exists: bool | None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
@@ -100,17 +98,11 @@ class ListAssetsQuery(BaseModel):
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_at_least_one_field(self):
|
||||
if all(
|
||||
v is None
|
||||
for v in (self.name, self.user_metadata, self.preview_id)
|
||||
):
|
||||
raise ValueError(
|
||||
"Provide at least one of: name, user_metadata, preview_id."
|
||||
)
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
@@ -118,11 +110,9 @@ class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str | None = None
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
mime_type: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
@@ -148,44 +138,6 @@ class CreateFromHashBody(BaseModel):
|
||||
return []
|
||||
|
||||
|
||||
class TagsRefineQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
limit: conint(ge=1, le=1000) = 100
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@@ -234,25 +186,21 @@ class TagsRemove(TagsAdd):
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
|
||||
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||
- mime_type: optional MIME type override
|
||||
- preview_id: optional asset_reference ID for preview
|
||||
|
||||
Files are stored using the content hash as filename stem.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
preview_id: str | None = Field(default=None) # references an asset_reference id
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
@@ -331,7 +279,7 @@ class UploadAssetSpec(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("at least one tag is required for uploads")
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
|
||||
@@ -4,10 +4,7 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class Asset(BaseModel):
|
||||
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
|
||||
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
@@ -15,14 +12,8 @@ class Asset(BaseModel):
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
is_immutable: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
prompt_id: str | None = None # deprecated: use job_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -32,16 +23,50 @@ class Asset(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(Asset):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _serialize_updated_at(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _serialize_datetime(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@@ -66,7 +91,3 @@ class TagsRemove(BaseModel):
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagHistogram(BaseModel):
|
||||
tag_counts: dict[str, int]
|
||||
|
||||
@@ -52,8 +52,6 @@ async def parse_multipart_upload(
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
@@ -130,16 +128,6 @@ async def parse_multipart_upload(
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
elif fname == "id":
|
||||
raise UploadError(
|
||||
400,
|
||||
"UNSUPPORTED_FIELD",
|
||||
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
|
||||
)
|
||||
elif fname == "mime_type":
|
||||
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||
elif fname == "preview_id":
|
||||
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
raise UploadError(
|
||||
@@ -164,8 +152,6 @@ async def parse_multipart_upload(
|
||||
user_metadata_raw=user_metadata_raw,
|
||||
provided_hash=provided_hash,
|
||||
provided_hash_exists=provided_hash_exists,
|
||||
provided_mime_type=provided_mime_type,
|
||||
provided_preview_id=provided_preview_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -45,7 +45,13 @@ class Asset(Base):
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
# preview_id on AssetReference is a self-referential FK to asset_references.id
|
||||
preview_of: Mapped[list[AssetReference]] = relationship(
|
||||
"AssetReference",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
|
||||
foreign_keys=lambda: [AssetReference.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
@@ -85,15 +91,11 @@ class AssetReference(Base):
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
|
||||
String(36), ForeignKey("assets.id", ondelete="SET NULL")
|
||||
)
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True)
|
||||
)
|
||||
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True), nullable=True, default=None
|
||||
)
|
||||
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
@@ -113,10 +115,10 @@ class AssetReference(Base):
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_ref: Mapped[AssetReference | None] = relationship(
|
||||
"AssetReference",
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
remote_side=lambda: [AssetReference.id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||
@@ -150,7 +152,6 @@ class AssetReference(Base):
|
||||
Index("ix_asset_references_created_at", "created_at"),
|
||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||
Index("ix_asset_references_preview_id", "preview_id"),
|
||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||
CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
@@ -191,10 +192,6 @@ class AssetReferenceMeta(Base):
|
||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||
CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="has_value",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,21 +31,16 @@ from app.assets.database.queries.asset_reference import (
|
||||
get_unenriched_references,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
insert_reference,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
list_references_page,
|
||||
mark_references_missing_outside_prefixes,
|
||||
rebuild_metadata_projection,
|
||||
reference_exists,
|
||||
reference_exists_for_asset_id,
|
||||
restore_references_by_paths,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_system_metadata,
|
||||
soft_delete_reference_by_id,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_is_missing_by_asset_id,
|
||||
update_reference_timestamps,
|
||||
update_reference_updated_at,
|
||||
upsert_reference,
|
||||
@@ -59,7 +54,6 @@ from app.assets.database.queries.tags import (
|
||||
bulk_insert_tags_and_meta,
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
list_tag_counts_for_filtered_assets,
|
||||
list_tags_with_usage,
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
@@ -103,26 +97,20 @@ __all__ = [
|
||||
"get_unenriched_references",
|
||||
"get_unreferenced_unhashed_asset_ids",
|
||||
"insert_reference",
|
||||
"list_all_file_paths_by_asset_id",
|
||||
"list_references_by_asset_id",
|
||||
"list_references_page",
|
||||
"list_tag_counts_for_filtered_assets",
|
||||
"list_tags_with_usage",
|
||||
"mark_references_missing_outside_prefixes",
|
||||
"reassign_asset_references",
|
||||
"rebuild_metadata_projection",
|
||||
"reference_exists",
|
||||
"reference_exists_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_tags_from_reference",
|
||||
"restore_references_by_paths",
|
||||
"set_reference_metadata",
|
||||
"set_reference_preview",
|
||||
"set_reference_system_metadata",
|
||||
"soft_delete_reference_by_id",
|
||||
"set_reference_tags",
|
||||
"update_asset_hash_and_mime",
|
||||
"update_is_missing_by_asset_id",
|
||||
"update_reference_access_time",
|
||||
"update_reference_name",
|
||||
"update_reference_timestamps",
|
||||
|
||||
@@ -69,7 +69,7 @@ def upsert_asset(
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and not asset.mime_type:
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
@@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
|
||||
return False
|
||||
if asset_hash is not None:
|
||||
asset.hash = asset_hash
|
||||
if mime_type is not None and not asset.mime_type:
|
||||
if mime_type is not None:
|
||||
asset.mime_type = mime_type
|
||||
return True
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from decimal import Decimal
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, exists, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, noload
|
||||
@@ -24,14 +24,12 @@ from app.assets.database.models import (
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
MAX_BIND_PARAMS,
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_prefix_like_conditions,
|
||||
build_visible_owner_clause,
|
||||
calculate_rows_per_statement,
|
||||
iter_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
def _check_is_scalar(v):
|
||||
@@ -46,6 +44,15 @@ def _check_is_scalar(v):
|
||||
|
||||
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
"""Convert a scalar value to a typed projection row."""
|
||||
if value is None:
|
||||
return {
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
if isinstance(value, bool):
|
||||
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
@@ -59,19 +66,96 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
||||
"""Turn a metadata key/value into typed projection rows."""
|
||||
if value is None:
|
||||
return []
|
||||
return [_scalar_to_row(key, 0, None)]
|
||||
|
||||
if _check_is_scalar(value):
|
||||
return [_scalar_to_row(key, 0, value)]
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(_check_is_scalar(x) for x in value):
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
|
||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
|
||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
|
||||
|
||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetReferenceMeta.val_json.is_(None),
|
||||
AssetReferenceMeta.val_str.is_(None),
|
||||
AssetReferenceMeta.val_num.is_(None),
|
||||
AssetReferenceMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def get_reference_by_id(
|
||||
@@ -128,21 +212,6 @@ def reference_exists_for_asset_id(
|
||||
return session.execute(q).first() is not None
|
||||
|
||||
|
||||
def reference_exists(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
) -> bool:
|
||||
"""Return True if a reference with the given ID exists (not soft-deleted)."""
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetReference)
|
||||
.where(AssetReference.id == reference_id)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.limit(1)
|
||||
)
|
||||
return session.execute(q).first() is not None
|
||||
|
||||
|
||||
def insert_reference(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -267,8 +336,8 @@ def list_references_page(
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
@@ -297,8 +366,8 @@ def list_references_page(
|
||||
count_stmt = count_stmt.where(
|
||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||
)
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int(session.execute(count_stmt).scalar_one() or 0)
|
||||
refs = session.execute(base).unique().scalars().all()
|
||||
@@ -310,7 +379,7 @@ def list_references_page(
|
||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
.order_by(AssetReferenceTag.added_at)
|
||||
)
|
||||
for ref_id, tag_name in rows.all():
|
||||
tag_map[ref_id].append(tag_name)
|
||||
@@ -423,42 +492,6 @@ def update_reference_updated_at(
|
||||
)
|
||||
|
||||
|
||||
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
|
||||
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
|
||||
|
||||
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
|
||||
override system keys of the same name.
|
||||
"""
|
||||
session.execute(
|
||||
delete(AssetReferenceMeta).where(
|
||||
AssetReferenceMeta.asset_reference_id == ref.id
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
|
||||
if not merged:
|
||||
return
|
||||
|
||||
rows: list[AssetReferenceMeta] = []
|
||||
for k, v in merged.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetReferenceMeta(
|
||||
asset_reference_id=ref.id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def set_reference_metadata(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
@@ -472,24 +505,33 @@ def set_reference_metadata(
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
|
||||
rebuild_metadata_projection(session, ref)
|
||||
|
||||
|
||||
def set_reference_system_metadata(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
system_metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""Set system_metadata on a reference and rebuild the merged projection."""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
ref.system_metadata = system_metadata or {}
|
||||
ref.updated_at = get_utc_now()
|
||||
session.execute(
|
||||
delete(AssetReferenceMeta).where(
|
||||
AssetReferenceMeta.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
rebuild_metadata_projection(session, ref)
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetReferenceMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in convert_metadata_to_rows(k, v):
|
||||
rows.append(
|
||||
AssetReferenceMeta(
|
||||
asset_reference_id=reference_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def delete_reference_by_id(
|
||||
@@ -529,19 +571,19 @@ def soft_delete_reference_by_id(
|
||||
def set_reference_preview(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
preview_reference_id: str | None = None,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
if preview_reference_id is None:
|
||||
if preview_asset_id is None:
|
||||
ref.preview_id = None
|
||||
else:
|
||||
if not session.get(AssetReference, preview_reference_id):
|
||||
raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
|
||||
ref.preview_id = preview_reference_id
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
ref.preview_id = preview_asset_id
|
||||
|
||||
ref.updated_at = get_utc_now()
|
||||
session.flush()
|
||||
@@ -567,8 +609,6 @@ def list_references_by_asset_id(
|
||||
session.execute(
|
||||
select(AssetReference)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.order_by(AssetReference.id.asc())
|
||||
)
|
||||
.scalars()
|
||||
@@ -576,25 +616,6 @@ def list_references_by_asset_id(
|
||||
)
|
||||
|
||||
|
||||
def list_all_file_paths_by_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> list[str]:
|
||||
"""Return every file_path for an asset, including soft-deleted/missing refs.
|
||||
|
||||
Used for orphan cleanup where all on-disk files must be removed.
|
||||
"""
|
||||
return list(
|
||||
session.execute(
|
||||
select(AssetReference.file_path)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.file_path.isnot(None))
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def upsert_reference(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
@@ -834,22 +855,6 @@ def bulk_update_is_missing(
|
||||
return total
|
||||
|
||||
|
||||
def update_is_missing_by_asset_id(
|
||||
session: Session, asset_id: str, value: bool
|
||||
) -> int:
|
||||
"""Set is_missing flag for ALL references belonging to an asset.
|
||||
|
||||
Returns: Number of rows updated
|
||||
"""
|
||||
result = session.execute(
|
||||
sa.update(AssetReference)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.values(is_missing=value)
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
||||
"""Delete references by their IDs.
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from typing import Iterable, Sequence
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
||||
from app.assets.helpers import escape_sql_like_string, normalize_tags
|
||||
from app.assets.database.models import AssetReference
|
||||
from app.assets.helpers import escape_sql_like_string
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
@@ -54,74 +52,3 @@ def build_prefix_like_conditions(
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||
return conds
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
return sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
@@ -8,15 +8,12 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
Asset,
|
||||
AssetReference,
|
||||
AssetReferenceMeta,
|
||||
AssetReferenceTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
@@ -75,9 +72,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
select(AssetReferenceTag.tag_name).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
@@ -120,7 +117,7 @@ def set_reference_tags(
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
|
||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
||||
|
||||
|
||||
def add_tags_to_reference(
|
||||
@@ -275,12 +272,6 @@ def list_tags_with_usage(
|
||||
.select_from(AssetReferenceTag)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.subquery()
|
||||
@@ -317,12 +308,6 @@ def list_tags_with_usage(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
)
|
||||
@@ -335,53 +320,6 @@ def list_tags_with_usage(
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def list_tag_counts_for_filtered_assets(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
"""Return tag counts for assets matching the given filters.
|
||||
|
||||
Uses the same filtering logic as list_references_page but returns
|
||||
{tag_name: count} instead of paginated references.
|
||||
"""
|
||||
# Build a subquery of matching reference IDs
|
||||
ref_sq = (
|
||||
select(AssetReference.id)
|
||||
.join(Asset, Asset.id == AssetReference.asset_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
|
||||
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
|
||||
ref_sq = ref_sq.subquery()
|
||||
|
||||
# Count tags across those references
|
||||
q = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name,
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
rows = session.execute(q).all()
|
||||
return {tag_name: int(cnt) for tag_name, cnt in rows}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
|
||||
@@ -18,7 +18,7 @@ from app.assets.database.queries import (
|
||||
mark_references_missing_outside_prefixes,
|
||||
reassign_asset_references,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_system_metadata,
|
||||
set_reference_metadata,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
@@ -490,8 +490,8 @@ def enrich_asset(
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
user_metadata = metadata.to_user_metadata()
|
||||
set_reference_metadata(session, reference_id, user_metadata)
|
||||
|
||||
if full_hash:
|
||||
existing = get_asset_by_hash(session, full_hash)
|
||||
|
||||
@@ -16,12 +16,10 @@ from app.assets.database.queries import (
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_updated_at,
|
||||
@@ -69,8 +67,6 @@ def update_asset_metadata(
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
@@ -107,21 +103,6 @@ def update_asset_metadata(
|
||||
)
|
||||
touched = True
|
||||
|
||||
if mime_type is not None:
|
||||
updated = update_asset_hash_and_mime(
|
||||
session, asset_id=ref.asset_id, mime_type=mime_type
|
||||
)
|
||||
if updated:
|
||||
touched = True
|
||||
|
||||
if preview_id is not None:
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_id,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_reference_updated_at(session, reference_id=reference_id)
|
||||
|
||||
@@ -178,9 +159,11 @@ def delete_asset_reference(
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - gather ALL file paths (including
|
||||
# soft-deleted / missing refs) so their on-disk files get cleaned up.
|
||||
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
|
||||
# Orphaned asset - delete it and its files
|
||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [
|
||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
||||
]
|
||||
# Also include the just-deleted file path
|
||||
if file_path:
|
||||
file_paths.append(file_path)
|
||||
@@ -202,7 +185,7 @@ def delete_asset_reference(
|
||||
|
||||
def set_asset_preview(
|
||||
reference_id: str,
|
||||
preview_reference_id: str | None = None,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
@@ -211,7 +194,7 @@ def set_asset_preview(
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_reference_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
@@ -280,47 +263,6 @@ def list_assets_page(
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
asset_hash: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult | None:
|
||||
"""Resolve a blake3 hash to an on-disk file path.
|
||||
|
||||
Only references visible to *owner_id* are considered (owner-less
|
||||
references are always visible).
|
||||
|
||||
Returns a DownloadResolutionResult with abs_path, content_type, and
|
||||
download_name, or None if no asset or live path is found.
|
||||
"""
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash)
|
||||
if not asset:
|
||||
return None
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
visible = [
|
||||
r for r in refs
|
||||
if r.owner_id == "" or r.owner_id == owner_id
|
||||
]
|
||||
abs_path = select_best_live_path(visible)
|
||||
if not abs_path:
|
||||
return None
|
||||
display_name = os.path.basename(abs_path)
|
||||
for ref in visible:
|
||||
if ref.file_path == abs_path and ref.name:
|
||||
display_name = ref.name
|
||||
break
|
||||
ctype = (
|
||||
asset.mime_type
|
||||
or mimetypes.guess_type(display_name)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=display_name,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
|
||||
@@ -11,14 +11,13 @@ from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
@@ -27,7 +26,6 @@ from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
@@ -67,7 +65,7 @@ def _ingest_file_from_path(
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
@@ -137,8 +135,6 @@ def _register_existing_asset(
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
@@ -147,25 +143,14 @@ def _register_existing_asset(
|
||||
if not asset:
|
||||
raise ValueError(f"No asset with hash {asset_hash}")
|
||||
|
||||
if mime_type and not asset.mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
|
||||
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
@@ -257,8 +242,6 @@ def upload_from_temp_path(
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult:
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||
@@ -287,8 +270,6 @@ def upload_from_temp_path(
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
@@ -310,7 +291,7 @@ def upload_from_temp_path(
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
validate_path_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = mime_type or (
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
@@ -334,7 +315,7 @@ def upload_from_temp_path(
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=preview_id,
|
||||
preview_id=None,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
@@ -361,99 +342,30 @@ def upload_from_temp_path(
|
||||
)
|
||||
|
||||
|
||||
def register_file_in_place(
|
||||
abs_path: str,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
) -> UploadResult:
|
||||
"""Register an already-saved file in the asset database without moving it.
|
||||
|
||||
Tags are derived from the filesystem path (root category + subfolder names),
|
||||
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||
If the path is not under a known root, only the caller-provided tags are used.
|
||||
"""
|
||||
try:
|
||||
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
except ValueError:
|
||||
path_tags = []
|
||||
merged_tags = normalize_tags([*path_tags, *tags])
|
||||
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
content_type = mime_type or (
|
||||
mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
abs_path=abs_path,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
tags=merged_tags,
|
||||
tag_origin="upload",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def create_from_hash(
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
|
||||
try:
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
except ValueError:
|
||||
logging.warning("create_from_hash: no asset found for hash %s", canonical)
|
||||
return None
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
|
||||
@@ -25,9 +25,7 @@ class ReferenceData:
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
system_metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
last_access_time: datetime | None = None
|
||||
last_access_time: datetime | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -95,8 +93,6 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
file_path=ref.file_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
system_metadata=ref.system_metadata,
|
||||
job_id=ref.job_id,
|
||||
created_at=ref.created_at,
|
||||
updated_at=ref.updated_at,
|
||||
last_access_time=ref.last_access_time,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
@@ -8,7 +6,6 @@ from app.assets.database.queries import (
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
|
||||
from app.assets.services.schemas import TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
@@ -76,23 +73,3 @@ def list_tags(
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
with create_session() as session:
|
||||
return list_tag_counts_for_filtered_assets(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -1,18 +1,9 @@
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
NAMING_CONVENTION = {
|
||||
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||
pass
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
|
||||
@@ -6,7 +6,6 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@@ -378,15 +377,8 @@ class UserManager():
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
dir_name = os.path.dirname(path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(body)
|
||||
os.replace(tmp_path, path)
|
||||
except:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
|
||||
@@ -83,8 +83,6 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||
|
||||
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
@@ -149,7 +147,6 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@@ -263,6 +260,4 @@ else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
if args.enable_dynamic_vram:
|
||||
return True
|
||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||
|
||||
@@ -209,39 +209,3 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
|
||||
def roundup(x_val, multiple):
|
||||
return ((x_val + multiple - 1) // multiple) * multiple
|
||||
|
||||
if pad_32x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 32)
|
||||
padded_cols = roundup(cols, 32)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
F8_E4M3_MAX = 448.0
|
||||
E8M0_BIAS = 127
|
||||
BLOCK_SIZE = 32
|
||||
|
||||
rows, cols = x.shape
|
||||
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
|
||||
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
|
||||
|
||||
# E8M0 block scales (power-of-2 exponents)
|
||||
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
|
||||
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
|
||||
block_scales_e8m0 = exp_biased.to(torch.uint8)
|
||||
|
||||
zero_mask = (max_abs == 0)
|
||||
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
|
||||
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
|
||||
|
||||
# Scale per-block then stochastic round
|
||||
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
|
||||
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
|
||||
|
||||
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
|
||||
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||
|
||||
@@ -343,7 +343,6 @@ class CrossAttention(nn.Module):
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
@@ -413,7 +412,6 @@ class Attention(nn.Module):
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
@@ -65,13 +65,9 @@ class CausalConv3d(nn.Module):
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
|
||||
x = torch.cat(pieces, dim=2)
|
||||
del pieces
|
||||
del cached
|
||||
|
||||
if needs_caching:
|
||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||
elif is_end:
|
||||
self.temporal_cache_state[tid] = (None, True)
|
||||
|
||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@@ -297,23 +296,7 @@ class Encoder(nn.Module):
|
||||
module.temporal_cache_state.pop(tid, None)
|
||||
|
||||
|
||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||
|
||||
def get_max_chunk_size(device: torch.device) -> int:
|
||||
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||
|
||||
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||
return MIN_CHUNK_SIZE
|
||||
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||
return MAX_CHUNK_SIZE
|
||||
|
||||
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||
)
|
||||
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""
|
||||
@@ -541,11 +524,8 @@ class Decoder(nn.Module):
|
||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||
|
||||
output = []
|
||||
max_chunk_size = get_max_chunk_size(sample.device)
|
||||
|
||||
def run_up(idx, sample_ref, ended):
|
||||
sample = sample_ref[0]
|
||||
sample_ref[0] = None
|
||||
def run_up(idx, sample, ended):
|
||||
if idx >= len(self.up_blocks):
|
||||
sample = self.conv_norm_out(sample)
|
||||
if timestep_shift_scale is not None:
|
||||
@@ -556,7 +536,7 @@ class Decoder(nn.Module):
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
output.append(sample)
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
@@ -573,21 +553,13 @@ class Decoder(nn.Module):
|
||||
return
|
||||
|
||||
total_bytes = sample.numel() * sample.element_size()
|
||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
|
||||
if num_chunks == 1:
|
||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||
next_sample_ref = [sample]
|
||||
del sample
|
||||
run_up(idx + 1, next_sample_ref, ended)
|
||||
return
|
||||
else:
|
||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
for chunk_idx, sample1 in enumerate(samples):
|
||||
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||
|
||||
run_up(0, [sample], True)
|
||||
run_up(0, sample, True)
|
||||
sample = torch.cat(output, dim=2)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
@@ -99,7 +99,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@@ -109,7 +109,22 @@ class Resample(nn.Module):
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
@@ -130,24 +145,19 @@ class Resample(nn.Module):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :]
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
||||
# # cache last frame of last two chunk
|
||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
|
||||
deferred_x = feat_cache[idx + 1]
|
||||
if deferred_x is not None:
|
||||
x = torch.cat([deferred_x, x], 2)
|
||||
feat_cache[idx + 1] = None
|
||||
|
||||
if x.shape[2] == 1 and not final:
|
||||
feat_cache[idx + 1] = x
|
||||
x = None
|
||||
|
||||
feat_idx[0] += 2
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
@@ -167,12 +177,19 @@ class ResidualBlock(nn.Module):
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -196,7 +213,7 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
@@ -266,10 +283,17 @@ class Encoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -279,16 +303,14 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if x is None:
|
||||
return None
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@@ -296,7 +318,14 @@ class Encoder3d(nn.Module):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -364,7 +393,14 @@ class Decoder3d(nn.Module):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
@@ -373,56 +409,42 @@ class Decoder3d(nn.Module):
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
out_chunks = []
|
||||
|
||||
def run_up(layer_idx, x_ref, feat_idx):
|
||||
x = x_ref[0]
|
||||
x_ref[0] = None
|
||||
if layer_idx >= len(self.upsamples):
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||
x = layer(x, feat_cache[feat_idx[0]])
|
||||
feat_cache[feat_idx[0]] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
out_chunks.append(x)
|
||||
return
|
||||
|
||||
layer = self.upsamples[layer_idx]
|
||||
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
||||
for frame_idx in range(x.shape[2]):
|
||||
run_up(
|
||||
layer_idx,
|
||||
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
||||
feat_idx.copy(),
|
||||
)
|
||||
del x
|
||||
return
|
||||
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
next_x_ref = [x]
|
||||
del x
|
||||
run_up(layer_idx + 1, next_x_ref, feat_idx)
|
||||
|
||||
run_up(0, [x], feat_idx)
|
||||
return out_chunks
|
||||
return x
|
||||
|
||||
|
||||
def count_cache_layers(model):
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@@ -460,12 +482,11 @@ class WanVAE(nn.Module):
|
||||
conv_idx = [0]
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
t = 1 + ((t - 1) // 4) * 4
|
||||
iter_ = 1 + (t - 1) // 2
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_cache_layers(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
if i == 0:
|
||||
@@ -475,23 +496,20 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx,
|
||||
final=(i == (iter_ - 1)))
|
||||
if out_ is None:
|
||||
continue
|
||||
feat_idx=conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
conv_idx = [0]
|
||||
# z: [b,c,t,h,w]
|
||||
iter_ = 1 + z.shape[2] // 2
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_cache_layers(self.decoder)
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@@ -502,8 +520,8 @@ class WanVAE(nn.Module):
|
||||
feat_idx=conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
out += out_
|
||||
return torch.cat(out, 2)
|
||||
out = torch.cat([out, out_], 2)
|
||||
return out
|
||||
|
||||
@@ -1,68 +1,9 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
thread_id: int
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
class TensorGeometry(NamedTuple):
|
||||
shape: any
|
||||
dtype: torch.dtype
|
||||
|
||||
@@ -400,7 +400,7 @@ try:
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
@@ -505,28 +505,6 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self._set_model(model)
|
||||
@@ -541,7 +519,6 @@ class LoadedModel:
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
self._patcher_finalizer.atexit = False
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
@@ -555,9 +532,6 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@@ -588,7 +562,6 @@ class LoadedModel:
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
self.model_finalizer.atexit = False
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
@@ -660,7 +633,7 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
@@ -673,14 +646,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
can_unload_sorted = sorted(can_unload)
|
||||
for x in can_unload_sorted:
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
ram_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
@@ -689,18 +661,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
if ram_to_free > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@@ -766,27 +729,17 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||
#want to do.
|
||||
#FIXME: This should subtract off the to_load current pin consumption.
|
||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@@ -1052,12 +1005,6 @@ def intermediate_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def intermediate_dtype():
|
||||
if args.fp16_intermediates:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def vae_device():
|
||||
if args.cpu_vae:
|
||||
return torch.device("cpu")
|
||||
@@ -1278,11 +1225,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
dest_view = dest_views.pop(0)
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
@@ -1720,19 +1662,6 @@ def supports_nvfp4_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_mxfp8_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if torch_version_numeric < (2, 10):
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
|
||||
@@ -297,9 +297,6 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
@@ -1066,10 +1063,6 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def pinned_memory_size(self):
|
||||
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
|
||||
@@ -1660,16 +1653,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
return freed
|
||||
|
||||
def pinned_memory_size(self):
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
|
||||
236
comfy/ops.py
236
comfy/ops.py
@@ -306,40 +306,10 @@ class CastWeightBiasOp:
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
@staticmethod
|
||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||
missing_keys, unexpected_keys, weight_shape,
|
||||
bias_shape=None):
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k, v in state_dict.items():
|
||||
key = k[prefix_len:]
|
||||
if key == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif bias_shape is not None and key == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
if module.weight is None:
|
||||
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "weight")
|
||||
|
||||
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
||||
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "bias")
|
||||
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
# don't trust subclasses that BYO state dict loader to call us.
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
@@ -360,21 +330,32 @@ class disable_weight_init:
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.in_features, self.out_features),
|
||||
bias_shape=(self.out_features,),
|
||||
)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k,v in state_dict.items():
|
||||
if k[prefix_len:] == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif k[prefix_len:] == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
#Reconcile default construction of the weight if its missing.
|
||||
if self.weight is None:
|
||||
v = torch.zeros(self.in_features, self.out_features)
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"weight")
|
||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||
v = torch.zeros(self.out_features,)
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"bias")
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
@@ -566,53 +547,6 @@ class disable_weight_init:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
|
||||
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
|
||||
_freeze=False, device=None, dtype=None):
|
||||
# don't trust subclasses that BYO state dict loader to call us.
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
|
||||
norm_type, scale_grad_by_freq, sparse, _weight,
|
||||
_freeze, device, dtype)
|
||||
return
|
||||
|
||||
torch.nn.Module.__init__(self)
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
# Keep shape/dtype visible for module introspection without reserving storage.
|
||||
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.bias = None
|
||||
self.weight_comfy_model_dtype = dtype
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
@@ -776,71 +710,6 @@ from .quant_ops import (
|
||||
)
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
q_input = inp
|
||||
|
||||
w = weight.detach() if weight.requires_grad else weight
|
||||
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
@@ -932,22 +801,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
@@ -1035,37 +888,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
_use_quantized = (
|
||||
getattr(self, 'layout_type', None) is not None and
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True
|
||||
)
|
||||
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
|
||||
output = QuantLinearFunc.apply(
|
||||
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||
)
|
||||
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@@ -1113,10 +939,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
@@ -1127,15 +950,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not mxfp8_compute:
|
||||
disabled.add("mxfp8")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@@ -13,31 +12,18 @@ def pin_memory(module):
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
pin = torch.empty((size,), dtype=torch.uint8)
|
||||
if comfy.model_management.pin_memory(pin):
|
||||
module._pin = pin
|
||||
else:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
try:
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
except RuntimeError:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
comfy.model_management.unpin_memory(module._pin)
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@@ -43,18 +43,6 @@ except ImportError as e:
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
_CK_MXFP8_AVAILABLE = False
|
||||
if _CK_AVAILABLE:
|
||||
try:
|
||||
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
|
||||
_CK_MXFP8_AVAILABLE = True
|
||||
except ImportError:
|
||||
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
|
||||
|
||||
if not _CK_MXFP8_AVAILABLE:
|
||||
class _CKMxfp8Layout:
|
||||
pass
|
||||
|
||||
import comfy.float
|
||||
|
||||
# ==============================================================================
|
||||
@@ -96,31 +84,6 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||
|
||||
params = cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
@@ -174,8 +137,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
@@ -196,14 +157,6 @@ QUANT_ALGOS = {
|
||||
},
|
||||
}
|
||||
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
QUANT_ALGOS["mxfp8"] = {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
|
||||
27
comfy/sd.py
27
comfy/sd.py
@@ -871,16 +871,13 @@ class VAE:
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
return pixels
|
||||
|
||||
def vae_output_dtype(self):
|
||||
return model_management.intermediate_dtype()
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
@@ -890,16 +887,16 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@@ -908,7 +905,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@@ -917,7 +914,7 @@ class VAE:
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
@@ -926,7 +923,7 @@ class VAE:
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
@@ -935,7 +932,7 @@ class VAE:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
@@ -953,9 +950,9 @@ class VAE:
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@@ -1028,9 +1025,9 @@ class VAE:
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -20,8 +20,6 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import ctypes
|
||||
import os
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
@@ -34,7 +32,7 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import mmap
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
@@ -83,17 +81,14 @@ _TYPES = {
|
||||
}
|
||||
|
||||
def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
|
||||
header_size = struct.unpack("<Q", mv[:8])[0]
|
||||
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
|
||||
|
||||
mv = mv[(data_base_offset := 8 + header_size):]
|
||||
mv = mv[8 + header_size:]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
@@ -107,14 +102,7 @@ def load_safetensors(ckpt):
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
@@ -897,10 +885,6 @@ def set_attr(obj, attr, value):
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
# Clone inference tensors (created under torch.inference_mode) since
|
||||
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||
value = value.clone()
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
|
||||
@@ -5,9 +5,6 @@ from comfy_api.latest._input import (
|
||||
MaskInput,
|
||||
LatentInput,
|
||||
VideoInput,
|
||||
CurveInput,
|
||||
MonotoneCubicCurve,
|
||||
LinearCurve,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -16,7 +13,4 @@ __all__ = [
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"VideoInput",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||
from .video_types import VideoInput
|
||||
|
||||
__all__ = [
|
||||
@@ -7,7 +7,4 @@ __all__ = [
|
||||
"VideoInput",
|
||||
"MaskInput",
|
||||
"LatentInput",
|
||||
"CurveInput",
|
||||
"MonotoneCubicCurve",
|
||||
"LinearCurve",
|
||||
]
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
@@ -45,190 +40,3 @@ class LatentInput(TypedDict):
|
||||
"""
|
||||
|
||||
batch_index: Optional[list[int]]
|
||||
|
||||
|
||||
CurvePoint = tuple[float, float]
|
||||
|
||||
|
||||
class CurveInput(ABC):
|
||||
"""Abstract base class for curve inputs.
|
||||
|
||||
Subclasses represent different curve representations (control-point
|
||||
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||
uniform evaluation interface to downstream nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def points(self) -> list[CurvePoint]:
|
||||
"""The control points that define this curve."""
|
||||
|
||||
@abstractmethod
|
||||
def interp(self, x: float) -> float:
|
||||
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||
|
||||
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||
"""Vectorised evaluation over a numpy array of x values.
|
||||
|
||||
Subclasses should override this for better performance. The default
|
||||
falls back to scalar ``interp`` calls.
|
||||
"""
|
||||
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||
|
||||
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||
|
||||
|
||||
class MonotoneCubicCurve(CurveInput):
|
||||
"""Monotone cubic Hermite interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||
backend evaluation matches the editor preview exactly.
|
||||
|
||||
All heavy work (sorting, slope computation) happens once at construction.
|
||||
``interp_array`` is fully vectorised with numpy.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
self._slopes = self._compute_slopes()
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def _compute_slopes(self) -> np.ndarray:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return np.zeros(n, dtype=np.float64)
|
||||
|
||||
dx = np.diff(xs)
|
||||
dy = np.diff(ys)
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||
|
||||
slopes = np.empty(n, dtype=np.float64)
|
||||
slopes[0] = deltas[0]
|
||||
slopes[-1] = deltas[-1]
|
||||
for i in range(1, n - 1):
|
||||
if deltas[i - 1] * deltas[i] <= 0:
|
||||
slopes[i] = 0.0
|
||||
else:
|
||||
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||
|
||||
for i in range(n - 1):
|
||||
if deltas[i] == 0:
|
||||
slopes[i] = 0.0
|
||||
slopes[i + 1] = 0.0
|
||||
else:
|
||||
alpha = slopes[i] / deltas[i]
|
||||
beta = slopes[i + 1] / deltas[i]
|
||||
s = alpha * alpha + beta * beta
|
||||
if s > 9:
|
||||
t = 3 / math.sqrt(s)
|
||||
slopes[i] = t * alpha * deltas[i]
|
||||
slopes[i + 1] = t * beta * deltas[i]
|
||||
return slopes
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
if x <= xs[0]:
|
||||
return float(ys[0])
|
||||
if x >= xs[-1]:
|
||||
return float(ys[-1])
|
||||
|
||||
hi = int(np.searchsorted(xs, x, side='right'))
|
||||
hi = min(hi, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
if dx == 0:
|
||||
return float(ys[lo])
|
||||
|
||||
t = (x - xs[lo]) / dx
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
"""Fully vectorised evaluation using numpy."""
|
||||
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if n == 1:
|
||||
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||
|
||||
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||
lo = hi - 1
|
||||
|
||||
dx = xs[hi] - xs[lo]
|
||||
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||
t2 = t * t
|
||||
t3 = t2 * t
|
||||
|
||||
h00 = 2 * t3 - 3 * t2 + 1
|
||||
h10 = t3 - 2 * t2 + t
|
||||
h01 = -2 * t3 + 3 * t2
|
||||
h11 = t3 - t2
|
||||
|
||||
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"MonotoneCubicCurve(points={self._points})"
|
||||
|
||||
|
||||
class LinearCurve(CurveInput):
|
||||
"""Piecewise linear interpolation over control points.
|
||||
|
||||
Mirrors the frontend ``createLinearInterpolator`` in
|
||||
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||
"""
|
||||
|
||||
def __init__(self, control_points: list[CurvePoint]):
|
||||
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||
|
||||
@property
|
||||
def points(self) -> list[CurvePoint]:
|
||||
return list(self._points)
|
||||
|
||||
def interp(self, x: float) -> float:
|
||||
xs, ys = self._xs, self._ys
|
||||
n = len(xs)
|
||||
if n == 0:
|
||||
return 0.0
|
||||
if n == 1:
|
||||
return float(ys[0])
|
||||
return float(np.interp(x, xs, ys))
|
||||
|
||||
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||
if len(self._xs) == 0:
|
||||
return np.zeros_like(xs_in, dtype=np.float64)
|
||||
if len(self._xs) == 1:
|
||||
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||
return np.interp(xs_in, self._xs, self._ys)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LinearCurve(points={self._points})"
|
||||
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from comfy.samplers import CFGGuider, Sampler
|
||||
from comfy.sd import CLIP, VAE
|
||||
from comfy.sd import StyleModel as StyleModel_
|
||||
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
@@ -1243,8 +1243,7 @@ class BoundingBox(ComfyTypeIO):
|
||||
@comfytype(io_type="CURVE")
|
||||
class Curve(ComfyTypeIO):
|
||||
CurvePoint = tuple[float, float]
|
||||
if TYPE_CHECKING:
|
||||
Type = CurveInput_
|
||||
Type = list[CurvePoint]
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
@@ -1253,12 +1252,6 @@ class Curve(ComfyTypeIO):
|
||||
if default is None:
|
||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.default is not None:
|
||||
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||
return d
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
|
||||
@@ -1459,7 +1459,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling 3.0 Omni Edit Video",
|
||||
category="api node/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
|
||||
@@ -833,7 +833,6 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
node_id="RecraftVectorizeImageNode",
|
||||
display_name="Recraft Vectorize Image",
|
||||
category="api node/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Generates SVG synchronously from an input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
|
||||
@@ -19,7 +19,6 @@ class EmptyLatentAudio(IO.ComfyNode):
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
@@ -186,7 +185,6 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.input import CurveInput, MonotoneCubicCurve, LinearCurve
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class CurveEditor(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CurveEditor",
|
||||
display_name="Curve Editor",
|
||||
category="utils",
|
||||
inputs=[
|
||||
io.Curve.Input("curve"),
|
||||
],
|
||||
outputs=[
|
||||
io.Curve.Output("curve"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, curve) -> io.NodeOutput:
|
||||
if isinstance(curve, CurveInput):
|
||||
return io.NodeOutput(curve)
|
||||
raw_points = curve["points"] if isinstance(curve, dict) else curve
|
||||
points = [(float(x), float(y)) for x, y in raw_points]
|
||||
interpolation = curve.get("interpolation", "monotone_cubic") if isinstance(curve, dict) else "monotone_cubic"
|
||||
if interpolation == "linear":
|
||||
return io.NodeOutput(LinearCurve(points))
|
||||
return io.NodeOutput(MonotoneCubicCurve(points))
|
||||
|
||||
|
||||
class CurveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [CurveEditor]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return CurveExtension()
|
||||
@@ -14,7 +14,6 @@ class ImageCompare(IO.ComfyNode):
|
||||
display_name="Image Compare",
|
||||
description="Compares two images side by side with a slider.",
|
||||
category="image",
|
||||
essentials_category="Image Tools",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
|
||||
@@ -58,7 +58,6 @@ class ImageCropV2(IO.ComfyNode):
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
|
||||
@@ -21,7 +21,6 @@ class Blend(io.ComfyNode):
|
||||
node_id="ImageBlend",
|
||||
display_name="Image Blend",
|
||||
category="image/postprocessing",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
io.Image.Input("image1"),
|
||||
io.Image.Input("image2"),
|
||||
|
||||
@@ -15,7 +15,6 @@ import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
@@ -139,7 +138,6 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
training_dtype=torch.bfloat16,
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
use_grad_scaler=False,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@@ -154,8 +152,6 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# GradScaler for fp16 training
|
||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
if bucket_latents is not None:
|
||||
self._init_bucket_data(bucket_latents)
|
||||
@@ -208,13 +204,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
batch_sigmas.requires_grad_(True),
|
||||
**batch_extra_args,
|
||||
)
|
||||
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
if bwd:
|
||||
bwd_loss = loss / self.grad_acc
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(bwd_loss).backward()
|
||||
else:
|
||||
bwd_loss.backward()
|
||||
bwd_loss.backward()
|
||||
return loss
|
||||
|
||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||
@@ -314,10 +307,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(total_loss).backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
@@ -358,18 +348,12 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.unscale_(self.optimizer)
|
||||
for param_groups in self.optimizer.param_groups:
|
||||
for param in param_groups["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.step(self.optimizer)
|
||||
self.grad_scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
ui_pbar.update(1)
|
||||
torch.cuda.empty_cache()
|
||||
@@ -1020,9 +1004,9 @@ class TrainLoraNode(io.ComfyNode):
|
||||
),
|
||||
io.Combo.Input(
|
||||
"training_dtype",
|
||||
options=["bf16", "fp32", "none"],
|
||||
options=["bf16", "fp32"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||
tooltip="The dtype to use for training.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"lora_dtype",
|
||||
@@ -1051,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
@@ -1136,32 +1120,22 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
if mp.is_dynamic():
|
||||
if not bypass_mode:
|
||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
||||
bypass_mode = True
|
||||
offloading = True
|
||||
elif offloading:
|
||||
if not bypass_mode:
|
||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
latents, dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
@@ -1227,7 +1201,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
@@ -1240,7 +1213,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
@@ -1365,7 +1337,7 @@ class SaveLoRA(io.ComfyNode):
|
||||
io.Int.Input(
|
||||
"steps",
|
||||
optional=True,
|
||||
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
|
||||
4
main.py
4
main.py
@@ -206,8 +206,8 @@ import hook_breaker_ac10a0
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
|
||||
@@ -1 +1 @@
|
||||
comfyui_manager==4.1b6
|
||||
comfyui_manager==4.1b2
|
||||
17
nodes.py
17
nodes.py
@@ -81,7 +81,6 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
|
||||
|
||||
class ConditioningCombine:
|
||||
ESSENTIALS_CATEGORY = "Image Generation"
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
|
||||
@@ -952,7 +951,7 @@ class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
@@ -1212,6 +1211,9 @@ class GLIGENTextBoxApply:
|
||||
return (c, )
|
||||
|
||||
class EmptyLatentImage:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
@@ -1230,7 +1232,7 @@ class EmptyLatentImage:
|
||||
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
|
||||
|
||||
|
||||
@@ -1722,8 +1724,6 @@ class LoadImage:
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
@@ -1748,8 +1748,8 @@ class LoadImage:
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
@@ -1779,7 +1779,6 @@ class LoadImage:
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
@@ -1888,7 +1887,6 @@ class ImageScale:
|
||||
return (s,)
|
||||
|
||||
class ImageScaleBy:
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
|
||||
@classmethod
|
||||
@@ -2453,7 +2451,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-frontend-package==1.41.19
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@@ -23,7 +23,7 @@ SQLAlchemy
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo>=0.2.12
|
||||
comfy-aimdo>=0.2.10
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
79
server.py
79
server.py
@@ -35,8 +35,6 @@ from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@@ -421,24 +419,7 @@ class PromptServer():
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
|
||||
|
||||
if args.enable_assets:
|
||||
try:
|
||||
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||
resp["asset"] = {
|
||||
"id": result.ref.id,
|
||||
"name": result.ref.name,
|
||||
"asset_hash": result.asset.hash,
|
||||
"size": result.asset.size_bytes,
|
||||
"mime_type": result.asset.mime_type,
|
||||
"tags": result.tags,
|
||||
}
|
||||
except Exception:
|
||||
logging.warning("Failed to register uploaded image as asset", exc_info=True)
|
||||
|
||||
return web.json_response(resp)
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@@ -498,43 +479,30 @@ class PromptServer():
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
# The frontend's LoadImage combo widget uses asset_hash values
|
||||
# (e.g. "blake3:...") as widget values. When litegraph renders the
|
||||
# node preview, it constructs /view?filename=<asset_hash>, so this
|
||||
# endpoint must resolve blake3 hashes to their on-disk file paths.
|
||||
if filename.startswith("blake3:"):
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
result = resolve_hash_to_path(filename, owner_id=owner_id)
|
||||
if result is None:
|
||||
return web.Response(status=404)
|
||||
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
|
||||
else:
|
||||
resolved_content_type = None
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
if "subfolder" in request.rel_url.query:
|
||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
if "subfolder" in request.rel_url.query:
|
||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
if 'preview' in request.rel_url.query:
|
||||
@@ -594,13 +562,8 @@ class PromptServer():
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
content_type = (
|
||||
resolved_content_type
|
||||
or mimetypes.guess_type(filename)[0]
|
||||
or 'application/octet-stream'
|
||||
)
|
||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
|
||||
|
||||
This catches problems like unnamed FK constraints that prevent batch-mode
|
||||
drop_constraint from working on real SQLite files (see MB-2).
|
||||
|
||||
Migrations 0001 and 0002 are already shipped, so we only exercise
|
||||
upgrade/downgrade for 0003+.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
|
||||
# Oldest shipped revision — we upgrade to here as a baseline and never
|
||||
# downgrade past it.
|
||||
_BASELINE = "0002_merge_to_asset_references"
|
||||
|
||||
|
||||
def _make_config(db_path: str) -> Config:
|
||||
root = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
|
||||
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_db(tmp_path):
|
||||
"""Yield an alembic Config pre-upgraded to the baseline revision."""
|
||||
db_path = str(tmp_path / "test_migration.db")
|
||||
cfg = _make_config(db_path)
|
||||
command.upgrade(cfg, _BASELINE)
|
||||
yield cfg
|
||||
|
||||
|
||||
def test_upgrade_to_head(migration_db):
|
||||
"""Upgrade from baseline to head must succeed on a file-backed DB."""
|
||||
command.upgrade(migration_db, "head")
|
||||
|
||||
|
||||
def test_downgrade_to_baseline(migration_db):
|
||||
"""Upgrade to head then downgrade back to baseline."""
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
|
||||
|
||||
def test_upgrade_downgrade_cycle(migration_db):
|
||||
"""Full cycle: upgrade → downgrade → upgrade again."""
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
command.upgrade(migration_db, "head")
|
||||
@@ -10,7 +10,6 @@ from app.assets.database.queries import (
|
||||
get_asset_by_hash,
|
||||
upsert_asset,
|
||||
bulk_insert_assets,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
|
||||
|
||||
@@ -143,45 +142,3 @@ class TestBulkInsertAssets:
|
||||
session.commit()
|
||||
|
||||
assert session.query(Asset).count() == 200
|
||||
|
||||
|
||||
class TestMimeTypeImmutability:
|
||||
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mime,second_mime,expected_mime",
|
||||
[
|
||||
("image/png", "image/jpeg", "image/png"),
|
||||
(None, "image/png", "image/png"),
|
||||
],
|
||||
ids=["preserves_existing", "fills_null"],
|
||||
)
|
||||
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
|
||||
h = f"blake3:upsert_{initial_mime}_{second_mime}"
|
||||
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
|
||||
session.commit()
|
||||
|
||||
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
|
||||
assert created is False
|
||||
assert asset.mime_type == expected_mime
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
|
||||
[
|
||||
(None, "image/png", None, "image/png", "blake3:upd0"),
|
||||
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
|
||||
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
|
||||
],
|
||||
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
|
||||
)
|
||||
def test_update_asset_hash_and_mime_immutability(
|
||||
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
|
||||
):
|
||||
h = expected_hash.removesuffix("_new")
|
||||
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
|
||||
assert asset.mime_type == expected_mime
|
||||
assert asset.hash == expected_hash
|
||||
|
||||
@@ -242,24 +242,22 @@ class TestSetReferencePreview:
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
session.commit()
|
||||
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
|
||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id == preview_ref.id
|
||||
assert ref.preview_id == preview_asset.id
|
||||
|
||||
def test_clears_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref.preview_id = preview_ref.id
|
||||
ref.preview_id = preview_asset.id
|
||||
session.commit()
|
||||
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
|
||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
@@ -267,15 +265,15 @@ class TestSetReferencePreview:
|
||||
|
||||
def test_raises_for_nonexistent_reference(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
|
||||
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)
|
||||
|
||||
def test_raises_for_nonexistent_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(ValueError, match="Preview AssetReference"):
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
|
||||
with pytest.raises(ValueError, match="Preview Asset"):
|
||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")
|
||||
|
||||
|
||||
class TestInsertReference:
|
||||
@@ -353,14 +351,13 @@ class TestUpdateReferenceTimestamps:
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
session.commit()
|
||||
|
||||
update_reference_timestamps(session, ref, preview_id=preview_ref.id)
|
||||
update_reference_timestamps(session, ref, preview_id=preview_asset.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id == preview_ref.id
|
||||
assert ref.preview_id == preview_asset.id
|
||||
|
||||
|
||||
class TestSetReferenceMetadata:
|
||||
|
||||
@@ -20,7 +20,6 @@ def _make_reference(
|
||||
asset: Asset,
|
||||
name: str,
|
||||
metadata: dict | None = None,
|
||||
system_metadata: dict | None = None,
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
@@ -28,7 +27,6 @@ def _make_reference(
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
user_metadata=metadata,
|
||||
system_metadata=system_metadata,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
@@ -36,10 +34,8 @@ def _make_reference(
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
|
||||
# Build merged projection: {**system_metadata, **user_metadata}
|
||||
merged = {**(system_metadata or {}), **(metadata or {})}
|
||||
if merged:
|
||||
for key, val in merged.items():
|
||||
if metadata:
|
||||
for key, val in metadata.items():
|
||||
for row in convert_metadata_to_rows(key, val):
|
||||
meta_row = AssetReferenceMeta(
|
||||
asset_reference_id=ref.id,
|
||||
@@ -186,46 +182,3 @@ class TestMetadataFilterEmptyDict:
|
||||
|
||||
refs, _, total = list_references_page(session, metadata_filter={})
|
||||
assert total == 2
|
||||
|
||||
|
||||
class TestSystemMetadataProjection:
|
||||
"""Tests for system_metadata merging into the filter projection."""
|
||||
|
||||
def test_system_metadata_keys_are_filterable(self, session: Session):
|
||||
"""system_metadata keys should appear in the merged projection."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(
|
||||
session, asset, "with_sys",
|
||||
system_metadata={"source": "scanner"},
|
||||
)
|
||||
_make_reference(session, asset, "without_sys")
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"source": "scanner"}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "with_sys"
|
||||
|
||||
def test_user_metadata_overrides_system_metadata(self, session: Session):
|
||||
"""user_metadata should win when both have the same key."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(
|
||||
session, asset, "overridden",
|
||||
metadata={"origin": "user_upload"},
|
||||
system_metadata={"origin": "auto_scan"},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Should match the user value, not the system value
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"origin": "user_upload"}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "overridden"
|
||||
|
||||
# Should NOT match the system value (it was overridden)
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"origin": "auto_scan"}
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
@@ -11,7 +11,6 @@ from app.assets.services import (
|
||||
delete_asset_reference,
|
||||
set_asset_preview,
|
||||
)
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
|
||||
@@ -220,33 +219,31 @@ class TestSetAssetPreview:
|
||||
asset = _make_asset(session, hash_val="blake3:main")
|
||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref_id = ref.id
|
||||
preview_ref_id = preview_ref.id
|
||||
preview_id = preview_asset.id
|
||||
session.commit()
|
||||
|
||||
set_asset_preview(
|
||||
reference_id=ref_id,
|
||||
preview_reference_id=preview_ref_id,
|
||||
preview_asset_id=preview_id,
|
||||
)
|
||||
|
||||
# Verify by re-fetching from DB
|
||||
session.expire_all()
|
||||
updated_ref = session.get(AssetReference, ref_id)
|
||||
assert updated_ref.preview_id == preview_ref_id
|
||||
assert updated_ref.preview_id == preview_id
|
||||
|
||||
def test_clears_preview(self, mock_create_session, session: Session):
|
||||
asset = _make_asset(session)
|
||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref.preview_id = preview_ref.id
|
||||
ref.preview_id = preview_asset.id
|
||||
ref_id = ref.id
|
||||
session.commit()
|
||||
|
||||
set_asset_preview(
|
||||
reference_id=ref_id,
|
||||
preview_reference_id=None,
|
||||
preview_asset_id=None,
|
||||
)
|
||||
|
||||
# Verify by re-fetching from DB
|
||||
@@ -266,45 +263,6 @@ class TestSetAssetPreview:
|
||||
with pytest.raises(PermissionError, match="not owner"):
|
||||
set_asset_preview(
|
||||
reference_id=ref.id,
|
||||
preview_reference_id=None,
|
||||
preview_asset_id=None,
|
||||
owner_id="user2",
|
||||
)
|
||||
|
||||
|
||||
class TestResolveHashToPath:
|
||||
def test_returns_none_for_unknown_hash(self, mock_create_session):
|
||||
result = resolve_hash_to_path("blake3:" + "a" * 64)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ref_owner, query_owner, expect_found",
|
||||
[
|
||||
("user1", "user1", True),
|
||||
("user1", "user2", False),
|
||||
("", "anyone", True),
|
||||
("", "", True),
|
||||
],
|
||||
ids=[
|
||||
"owner_sees_own_ref",
|
||||
"other_owner_blocked",
|
||||
"ownerless_visible_to_anyone",
|
||||
"ownerless_visible_to_empty",
|
||||
],
|
||||
)
|
||||
def test_owner_visibility(
|
||||
self, ref_owner, query_owner, expect_found,
|
||||
mock_create_session, session: Session, temp_dir,
|
||||
):
|
||||
f = temp_dir / "file.bin"
|
||||
f.write_bytes(b"data")
|
||||
asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
|
||||
ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
|
||||
ref.file_path = str(f)
|
||||
session.commit()
|
||||
|
||||
result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
|
||||
if expect_found:
|
||||
assert result is not None
|
||||
assert result.abs_path == str(f)
|
||||
else:
|
||||
assert result is None
|
||||
|
||||
@@ -113,19 +113,11 @@ class TestIngestFileFromPath:
|
||||
file_path = temp_dir / "with_preview.bin"
|
||||
file_path.write_bytes(b"data")
|
||||
|
||||
# Create a preview asset and reference
|
||||
# Create a preview asset first
|
||||
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
|
||||
session.add(preview_asset)
|
||||
session.flush()
|
||||
from app.assets.helpers import get_utc_now
|
||||
now = get_utc_now()
|
||||
preview_ref = AssetReference(
|
||||
asset_id=preview_asset.id, name="preview.png", owner_id="",
|
||||
created_at=now, updated_at=now, last_access_time=now,
|
||||
)
|
||||
session.add(preview_ref)
|
||||
session.commit()
|
||||
preview_id = preview_ref.id
|
||||
preview_id = preview_asset.id
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(file_path),
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
"""Tests for list_tag_histogram service function."""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
|
||||
asset = Asset(hash=hash_val, size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def _make_reference(
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
name: str = "test",
|
||||
owner_id: str = "",
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
|
||||
class TestListTagHistogram:
|
||||
def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="r2")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram()
|
||||
|
||||
assert result["alpha"] == 2
|
||||
assert result["beta"] == 1
|
||||
|
||||
def test_empty_when_no_assets(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["unused"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram()
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_include_tags_filter(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="r2")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(include_tags=["models"])
|
||||
|
||||
# Only r1 has "models", so only its tags appear
|
||||
assert "models" in result
|
||||
assert "loras" in result
|
||||
assert "input" not in result
|
||||
|
||||
def test_exclude_tags_filter(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="r2")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(exclude_tags=["models"])
|
||||
|
||||
# r1 excluded, only r2's tags remain
|
||||
assert "input" in result
|
||||
assert "loras" not in result
|
||||
|
||||
def test_name_contains_filter(self, mock_create_session, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
a1 = _make_asset(session, "blake3:aaa")
|
||||
r1 = _make_reference(session, a1, name="my_model.safetensors")
|
||||
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
|
||||
|
||||
a2 = _make_asset(session, "blake3:bbb")
|
||||
r2 = _make_reference(session, a2, name="picture.png")
|
||||
add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(name_contains="model")
|
||||
|
||||
assert "alpha" in result
|
||||
assert "beta" not in result
|
||||
|
||||
def test_limit_caps_results(self, mock_create_session, session: Session):
|
||||
tags = [f"tag{i}" for i in range(10)]
|
||||
ensure_tags_exist(session, tags)
|
||||
a = _make_asset(session, "blake3:aaa")
|
||||
r = _make_reference(session, a, name="r1")
|
||||
add_tags_to_reference(session, reference_id=r.id, tags=tags)
|
||||
session.commit()
|
||||
|
||||
result = list_tag_histogram(limit=3)
|
||||
|
||||
assert len(result) == 3
|
||||
@@ -243,15 +243,6 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
|
||||
|
||||
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
|
||||
files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
|
||||
form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
|
||||
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||
root: str,
|
||||
|
||||
Reference in New Issue
Block a user