mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 11:21:34 +00:00

Compare commits: 4 commits on luke-mino-...fix/api-no

| Author | SHA1 | Date |
|---|---|---|
|  | 3ecca084b5 |  |
|  | e2096c2fc4 |  |
|  | 1ea727b4bc |  |
|  | 858977ab10 |  |

README.md (11 lines changed)
```diff
@@ -38,8 +38,6 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
 
 ## Get Started
 
-### Local
-
 #### [Desktop Application](https://www.comfy.org/download)
 - The easiest way to get started.
 - Available on Windows & macOS.
```
```diff
@@ -51,13 +49,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
 #### [Manual Install](#manual-install-windows-linux)
 Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
 
-### Cloud
-
-#### [Comfy Cloud](https://www.comfy.org/cloud)
-- Our official paid cloud version for those who can't afford local hardware.
-
-## Examples
-See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
+## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
+See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
 
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
```
```diff
@@ -8,7 +8,7 @@ from alembic import context
 config = context.config
 
 
-from app.database.models import Base, NAMING_CONVENTION
+from app.database.models import Base
 target_metadata = Base.metadata
 
 # other values from the config, defined by the needs of env.py,
```
```diff
@@ -51,10 +51,7 @@ def run_migrations_online() -> None:
 
     with connectable.connect() as connection:
         context.configure(
-            connection=connection,
-            target_metadata=target_metadata,
-            render_as_batch=True,
-            naming_convention=NAMING_CONVENTION,
+            connection=connection, target_metadata=target_metadata
         )
 
         with context.begin_transaction():
```
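The two options dropped from `context.configure` matter mostly for SQLite: `render_as_batch=True` makes migrations run through Alembic's batch ("move and copy") table-rebuild mode, and `naming_convention` gives auto-generated constraints deterministic names that later migrations can drop by name. A minimal sketch of the convention mechanism; the dict below is an assumption consistent with the constraint names in the deleted migration that follows, not the actual `NAMING_CONVENTION` from `app.database.models`:

```python
from sqlalchemy import Column, ForeignKey, MetaData, String, Table

# Assumed convention dict (illustrative only):
NAMING_CONVENTION = {
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
}

metadata = MetaData(naming_convention=NAMING_CONVENTION)
assets = Table("assets", metadata, Column("id", String(36), primary_key=True))
refs = Table(
    "asset_references",
    metadata,
    Column("id", String(36), primary_key=True),
    Column("preview_id", String(36), ForeignKey("assets.id")),
)
# Under the "fk" template, the FK above is auto-named
# "fk_asset_references_preview_id_assets" -- the exact name the deleted
# migration below drops in upgrade() and recreates in downgrade().
```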
```diff
@@ -1,98 +0,0 @@
-"""
-Add system_metadata and job_id columns to asset_references.
-Change preview_id FK from assets.id to asset_references.id.
-
-Revision ID: 0003_add_metadata_job_id
-Revises: 0002_merge_to_asset_references
-Create Date: 2026-03-09
-"""
-
-from alembic import op
-import sqlalchemy as sa
-
-from app.database.models import NAMING_CONVENTION
-
-revision = "0003_add_metadata_job_id"
-down_revision = "0002_merge_to_asset_references"
-branch_labels = None
-depends_on = None
-
-
-def upgrade() -> None:
-    with op.batch_alter_table("asset_references") as batch_op:
-        batch_op.add_column(
-            sa.Column("system_metadata", sa.JSON(), nullable=True)
-        )
-        batch_op.add_column(
-            sa.Column("job_id", sa.String(length=36), nullable=True)
-        )
-
-    # Change preview_id FK from assets.id to asset_references.id (self-ref).
-    # Existing values are asset-content IDs that won't match reference IDs,
-    # so null them out first.
-    op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
-    with op.batch_alter_table(
-        "asset_references", naming_convention=NAMING_CONVENTION
-    ) as batch_op:
-        batch_op.drop_constraint(
-            "fk_asset_references_preview_id_assets", type_="foreignkey"
-        )
-        batch_op.create_foreign_key(
-            "fk_asset_references_preview_id_asset_references",
-            "asset_references",
-            ["preview_id"],
-            ["id"],
-            ondelete="SET NULL",
-        )
-        batch_op.create_index(
-            "ix_asset_references_preview_id", ["preview_id"]
-        )
-
-    # Purge any all-null meta rows before adding the constraint
-    op.execute(
-        "DELETE FROM asset_reference_meta"
-        " WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
-    )
-    with op.batch_alter_table("asset_reference_meta") as batch_op:
-        batch_op.create_check_constraint(
-            "ck_asset_reference_meta_has_value",
-            "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
-        )
-
-
-def downgrade() -> None:
-    # SQLite doesn't reflect CHECK constraints, so we must declare it
-    # explicitly via table_args for the batch recreate to find it.
-    # Use the fully-rendered constraint name to avoid the naming convention
-    # doubling the prefix.
-    with op.batch_alter_table(
-        "asset_reference_meta",
-        table_args=[
-            sa.CheckConstraint(
-                "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
-                name="ck_asset_reference_meta_has_value",
-            ),
-        ],
-    ) as batch_op:
-        batch_op.drop_constraint(
-            "ck_asset_reference_meta_has_value", type_="check"
-        )
-
-    with op.batch_alter_table(
-        "asset_references", naming_convention=NAMING_CONVENTION
-    ) as batch_op:
-        batch_op.drop_index("ix_asset_references_preview_id")
-        batch_op.drop_constraint(
-            "fk_asset_references_preview_id_asset_references", type_="foreignkey"
-        )
-        batch_op.create_foreign_key(
-            "fk_asset_references_preview_id_assets",
-            "assets",
-            ["preview_id"],
-            ["id"],
-            ondelete="SET NULL",
-        )
-
-    with op.batch_alter_table("asset_references") as batch_op:
-        batch_op.drop_column("job_id")
-        batch_op.drop_column("system_metadata")
```
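One detail worth flagging in the deleted `downgrade()`: SQLite reflection does not pick up CHECK constraints, so a plain batch recreate would neither find the constraint to drop nor carry it over. Passing the constraint through `batch_alter_table`'s `table_args`, under its fully-rendered name, is the standard Alembic workaround. A reduced sketch of the same pattern:

```python
import sqlalchemy as sa
from alembic import op


def drop_unreflected_check(table: str, name: str, sqltext: str) -> None:
    """Drop a CHECK constraint that SQLite reflection can't see.

    Re-declaring the constraint via table_args makes it visible to the
    batch "move and copy" recreate, so drop_constraint can remove it.
    """
    with op.batch_alter_table(
        table,
        table_args=[sa.CheckConstraint(sqltext, name=name)],
    ) as batch_op:
        batch_op.drop_constraint(name, type_="check")
```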
```diff
@@ -13,7 +13,6 @@ from pydantic import ValidationError
 import folder_paths
 from app import user_manager
 from app.assets.api import schemas_in, schemas_out
-from app.assets.services import schemas
 from app.assets.api.schemas_in import (
     AssetValidationError,
     UploadError,
```
```diff
@@ -39,7 +38,6 @@ from app.assets.services import (
     update_asset_metadata,
     upload_from_temp_path,
 )
-from app.assets.services.tagging import list_tag_histogram
 
 ROUTES = web.RouteTableDef()
 USER_MANAGER: user_manager.UserManager | None = None
```
```diff
@@ -124,61 +122,6 @@ def _validate_sort_field(requested: str | None) -> str:
     return "created_at"
 
 
-def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
-    """Build a /api/view preview URL from asset tags and user_metadata filename."""
-    if not user_metadata:
-        return None
-    filename = user_metadata.get("filename")
-    if not filename:
-        return None
-
-    if "input" in tags:
-        view_type = "input"
-    elif "output" in tags:
-        view_type = "output"
-    else:
-        return None
-
-    subfolder = ""
-    if "/" in filename:
-        subfolder, filename = filename.rsplit("/", 1)
-
-    encoded_filename = urllib.parse.quote(filename, safe="")
-    url = f"/api/view?type={view_type}&filename={encoded_filename}"
-    if subfolder:
-        url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
-    return url
-
-
-def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
-    """Build an Asset response from a service result."""
-    if result.ref.preview_id:
-        preview_detail = get_asset_detail(result.ref.preview_id)
-        if preview_detail:
-            preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
-        else:
-            preview_url = None
-    else:
-        preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
-    return schemas_out.Asset(
-        id=result.ref.id,
-        name=result.ref.name,
-        asset_hash=result.asset.hash if result.asset else None,
-        size=int(result.asset.size_bytes) if result.asset else None,
-        mime_type=result.asset.mime_type if result.asset else None,
-        tags=result.tags,
-        preview_url=preview_url,
-        preview_id=result.ref.preview_id,
-        user_metadata=result.ref.user_metadata or {},
-        metadata=result.ref.system_metadata,
-        job_id=result.ref.job_id,
-        prompt_id=result.ref.job_id,  # deprecated: mirrors job_id for cloud compat
-        created_at=result.ref.created_at,
-        updated_at=result.ref.updated_at,
-        last_access_time=result.ref.last_access_time,
-    )
-
-
 @ROUTES.head("/api/assets/hash/{hash}")
 @_require_assets_feature_enabled
 async def head_asset_by_hash(request: web.Request) -> web.Response:
```
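For reference, the removed `_build_preview_url_from_view` helper mapped an asset's stored filename onto the legacy `/api/view` endpoint, splitting off the last path segment as the subfolder and URL-quoting each part separately. Traced on hypothetical inputs:

```python
# Hypothetical inputs -> output of the removed helper:
tags = ["output", "images"]
user_metadata = {"filename": "renders/final image.png"}
# "renders/final image.png" -> subfolder="renders", filename="final image.png"
# -> "/api/view?type=output&filename=final%20image.png&subfolder=renders"
# With no "input"/"output" tag, or no filename in user_metadata, it returned None.
```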
```diff
@@ -221,7 +164,20 @@ async def list_assets_route(request: web.Request) -> web.Response:
         order=order,
     )
 
-    summaries = [_build_asset_response(item) for item in result.items]
+    summaries = [
+        schemas_out.AssetSummary(
+            id=item.ref.id,
+            name=item.ref.name,
+            asset_hash=item.asset.hash if item.asset else None,
+            size=int(item.asset.size_bytes) if item.asset else None,
+            mime_type=item.asset.mime_type if item.asset else None,
+            tags=item.tags,
+            created_at=item.ref.created_at,
+            updated_at=item.ref.updated_at,
+            last_access_time=item.ref.last_access_time,
+        )
+        for item in result.items
+    ]
 
     payload = schemas_out.AssetsList(
         assets=summaries,
```
```diff
@@ -251,7 +207,18 @@ async def get_asset_route(request: web.Request) -> web.Response:
                 {"id": reference_id},
             )
 
-        payload = _build_asset_response(result)
+        payload = schemas_out.AssetDetail(
+            id=result.ref.id,
+            name=result.ref.name,
+            asset_hash=result.asset.hash if result.asset else None,
+            size=int(result.asset.size_bytes) if result.asset else None,
+            mime_type=result.asset.mime_type if result.asset else None,
+            tags=result.tags,
+            user_metadata=result.ref.user_metadata or {},
+            preview_id=result.ref.preview_id,
+            created_at=result.ref.created_at,
+            last_access_time=result.ref.last_access_time,
+        )
     except ValueError as e:
         return _build_error_response(
             404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
```
```diff
@@ -263,7 +230,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
             USER_MANAGER.get_request_user_id(request),
         )
         return _build_error_response(500, "INTERNAL", "Unexpected server error.")
-    return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
+    return web.json_response(payload.model_dump(mode="json"), status=200)
 
 
 @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
```
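This hunk, and the matching ones below, also drop `exclude_none=True` from every `model_dump` call, so optional fields now serialize as explicit JSON nulls instead of being omitted. Standard Pydantic v2 behavior, shown on a hypothetical model:

```python
from pydantic import BaseModel


class Thing(BaseModel):  # hypothetical model for illustration
    id: str
    mime_type: str | None = None


Thing(id="a").model_dump(mode="json", exclude_none=True)  # {'id': 'a'}
Thing(id="a").model_dump(mode="json")                     # {'id': 'a', 'mime_type': None}
```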
```diff
@@ -345,31 +312,32 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
             400, "INVALID_JSON", "Request body must be valid JSON."
         )
 
-    # Derive name from hash if not provided
-    name = body.name
-    if name is None:
-        name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
-
     result = create_from_hash(
         hash_str=body.hash,
-        name=name,
+        name=body.name,
         tags=body.tags,
         user_metadata=body.user_metadata,
         owner_id=USER_MANAGER.get_request_user_id(request),
-        mime_type=body.mime_type,
-        preview_id=body.preview_id,
     )
     if result is None:
         return _build_error_response(
             404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
         )
 
-    asset = _build_asset_response(result)
     payload_out = schemas_out.AssetCreated(
-        **asset.model_dump(),
+        id=result.ref.id,
+        name=result.ref.name,
+        asset_hash=result.asset.hash,
+        size=int(result.asset.size_bytes),
+        mime_type=result.asset.mime_type,
+        tags=result.tags,
+        user_metadata=result.ref.user_metadata or {},
+        preview_id=result.ref.preview_id,
+        created_at=result.ref.created_at,
+        last_access_time=result.ref.last_access_time,
         created_new=result.created_new,
     )
-    return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
+    return web.json_response(payload_out.model_dump(mode="json"), status=201)
 
 
 @ROUTES.post("/api/assets")
```
```diff
@@ -390,8 +358,6 @@ async def upload_asset(request: web.Request) -> web.Response:
                 "name": parsed.provided_name,
                 "user_metadata": parsed.user_metadata_raw,
                 "hash": parsed.provided_hash,
-                "mime_type": parsed.provided_mime_type,
-                "preview_id": parsed.provided_preview_id,
             }
         )
     except ValidationError as ve:
```
```diff
@@ -420,8 +386,6 @@ async def upload_asset(request: web.Request) -> web.Response:
             tags=spec.tags,
             user_metadata=spec.user_metadata or {},
             owner_id=owner_id,
-            mime_type=spec.mime_type,
-            preview_id=spec.preview_id,
         )
         if result is None:
             delete_temp_file_if_exists(parsed.tmp_path)
```
```diff
@@ -446,8 +410,6 @@ async def upload_asset(request: web.Request) -> web.Response:
             client_filename=parsed.file_client_name,
             owner_id=owner_id,
             expected_hash=spec.hash,
-            mime_type=spec.mime_type,
-            preview_id=spec.preview_id,
         )
     except AssetValidationError as e:
         delete_temp_file_if_exists(parsed.tmp_path)
```
```diff
@@ -466,13 +428,21 @@ async def upload_asset(request: web.Request) -> web.Response:
         logging.exception("upload_asset failed for owner_id=%s", owner_id)
         return _build_error_response(500, "INTERNAL", "Unexpected server error.")
 
-    asset = _build_asset_response(result)
-    payload_out = schemas_out.AssetCreated(
-        **asset.model_dump(),
+    payload = schemas_out.AssetCreated(
+        id=result.ref.id,
+        name=result.ref.name,
+        asset_hash=result.asset.hash,
+        size=int(result.asset.size_bytes),
+        mime_type=result.asset.mime_type,
+        tags=result.tags,
+        user_metadata=result.ref.user_metadata or {},
+        preview_id=result.ref.preview_id,
+        created_at=result.ref.created_at,
+        last_access_time=result.ref.last_access_time,
         created_new=result.created_new,
     )
     status = 201 if result.created_new else 200
-    return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
+    return web.json_response(payload.model_dump(mode="json"), status=status)
 
 
 @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
```
```diff
@@ -494,9 +464,15 @@ async def update_asset_route(request: web.Request) -> web.Response:
             name=body.name,
             user_metadata=body.user_metadata,
             owner_id=USER_MANAGER.get_request_user_id(request),
-            preview_id=body.preview_id,
         )
-        payload = _build_asset_response(result)
+        payload = schemas_out.AssetUpdated(
+            id=result.ref.id,
+            name=result.ref.name,
+            asset_hash=result.asset.hash if result.asset else None,
+            tags=result.tags,
+            user_metadata=result.ref.user_metadata or {},
+            updated_at=result.ref.updated_at,
+        )
     except PermissionError as pe:
         return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
     except ValueError as ve:
```
```diff
@@ -510,7 +486,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
             USER_MANAGER.get_request_user_id(request),
         )
         return _build_error_response(500, "INTERNAL", "Unexpected server error.")
-    return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
+    return web.json_response(payload.model_dump(mode="json"), status=200)
 
 
 @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
```
```diff
@@ -579,7 +555,7 @@ async def get_tags(request: web.Request) -> web.Response:
     payload = schemas_out.TagsList(
         tags=tags, total=total, has_more=(query.offset + len(tags)) < total
     )
-    return web.json_response(payload.model_dump(mode="json", exclude_none=True))
+    return web.json_response(payload.model_dump(mode="json"))
 
 
 @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
```
```diff
@@ -627,7 +603,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
         )
         return _build_error_response(500, "INTERNAL", "Unexpected server error.")
 
-    return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
+    return web.json_response(payload.model_dump(mode="json"), status=200)
 
 
 @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
```
```diff
@@ -674,29 +650,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
         )
         return _build_error_response(500, "INTERNAL", "Unexpected server error.")
 
-    return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
-
-
-@ROUTES.get("/api/assets/tags/refine")
-@_require_assets_feature_enabled
-async def get_tags_refine(request: web.Request) -> web.Response:
-    """GET request to get tag histogram for filtered assets."""
-    query_dict = get_query_dict(request)
-    try:
-        q = schemas_in.TagsRefineQuery.model_validate(query_dict)
-    except ValidationError as ve:
-        return _build_validation_error_response("INVALID_QUERY", ve)
-
-    tag_counts = list_tag_histogram(
-        owner_id=USER_MANAGER.get_request_user_id(request),
-        include_tags=q.include_tags,
-        exclude_tags=q.exclude_tags,
-        name_contains=q.name_contains,
-        metadata_filter=q.metadata_filter,
-        limit=q.limit,
-    )
-    payload = schemas_out.TagHistogram(tag_counts=tag_counts)
-    return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
+    return web.json_response(payload.model_dump(mode="json"), status=200)
 
 
 @ROUTES.post("/api/assets/seed")
```
```diff
@@ -45,8 +45,6 @@ class ParsedUpload:
     user_metadata_raw: str | None
     provided_hash: str | None
    provided_hash_exists: bool | None
-    provided_mime_type: str | None = None
-    provided_preview_id: str | None = None
 
 
 class ListAssetsQuery(BaseModel):
```
```diff
@@ -100,17 +98,11 @@ class ListAssetsQuery(BaseModel):
 class UpdateAssetBody(BaseModel):
     name: str | None = None
     user_metadata: dict[str, Any] | None = None
-    preview_id: str | None = None  # references an asset_reference id, not an asset id
 
     @model_validator(mode="after")
     def _validate_at_least_one_field(self):
-        if all(
-            v is None
-            for v in (self.name, self.user_metadata, self.preview_id)
-        ):
-            raise ValueError(
-                "Provide at least one of: name, user_metadata, preview_id."
-            )
+        if self.name is None and self.user_metadata is None:
+            raise ValueError("Provide at least one of: name, user_metadata.")
         return self
 
 
```
```diff
@@ -118,11 +110,9 @@ class CreateFromHashBody(BaseModel):
     model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
 
     hash: str
-    name: str | None = None
+    name: str
     tags: list[str] = Field(default_factory=list)
     user_metadata: dict[str, Any] = Field(default_factory=dict)
-    mime_type: str | None = None
-    preview_id: str | None = None  # references an asset_reference id, not an asset id
 
     @field_validator("hash")
     @classmethod
```
```diff
@@ -148,44 +138,6 @@ class CreateFromHashBody(BaseModel):
         return []
 
 
-class TagsRefineQuery(BaseModel):
-    include_tags: list[str] = Field(default_factory=list)
-    exclude_tags: list[str] = Field(default_factory=list)
-    name_contains: str | None = None
-    metadata_filter: dict[str, Any] | None = None
-    limit: conint(ge=1, le=1000) = 100
-
-    @field_validator("include_tags", "exclude_tags", mode="before")
-    @classmethod
-    def _split_csv_tags(cls, v):
-        if v is None:
-            return []
-        if isinstance(v, str):
-            return [t.strip() for t in v.split(",") if t.strip()]
-        if isinstance(v, list):
-            out: list[str] = []
-            for item in v:
-                if isinstance(item, str):
-                    out.extend([t.strip() for t in item.split(",") if t.strip()])
-            return out
-        return v
-
-    @field_validator("metadata_filter", mode="before")
-    @classmethod
-    def _parse_metadata_json(cls, v):
-        if v is None or isinstance(v, dict):
-            return v
-        if isinstance(v, str) and v.strip():
-            try:
-                parsed = json.loads(v)
-            except Exception as e:
-                raise ValueError(f"metadata_filter must be JSON: {e}") from e
-            if not isinstance(parsed, dict):
-                raise ValueError("metadata_filter must be a JSON object")
-            return parsed
-        return None
-
-
 class TagsListQuery(BaseModel):
     model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
 
```
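The removed `_split_csv_tags` validator accepted both comma-separated strings and repeated query parameters (which arrive as lists). A standalone copy of its logic, runnable outside the model:

```python
def split_csv_tags(v):
    """Standalone copy of the removed TagsRefineQuery validator logic."""
    if v is None:
        return []
    if isinstance(v, str):
        return [t.strip() for t in v.split(",") if t.strip()]
    if isinstance(v, list):
        out: list[str] = []
        for item in v:
            if isinstance(item, str):
                out.extend(t.strip() for t in item.split(",") if t.strip())
        return out
    return v


assert split_csv_tags("models, checkpoints,") == ["models", "checkpoints"]
assert split_csv_tags(["a,b", " c "]) == ["a", "b", "c"]
assert split_csv_tags(None) == []
```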
```diff
@@ -234,25 +186,21 @@ class TagsRemove(TagsAdd):
 class UploadAssetSpec(BaseModel):
     """Upload Asset operation.
 
-    - tags: optional list; if provided, first is root ('models'|'input'|'output');
+    - tags: ordered; first is root ('models'|'input'|'output');
       if root == 'models', second must be a valid category
     - name: display name
     - user_metadata: arbitrary JSON object (optional)
     - hash: optional canonical 'blake3:<hex>' for validation / fast-path
-    - mime_type: optional MIME type override
-    - preview_id: optional asset_reference ID for preview
 
     Files are stored using the content hash as filename stem.
     """
 
     model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
 
-    tags: list[str] = Field(default_factory=list)
+    tags: list[str] = Field(..., min_length=1)
     name: str | None = Field(default=None, max_length=512, description="Display Name")
     user_metadata: dict[str, Any] = Field(default_factory=dict)
     hash: str | None = Field(default=None)
-    mime_type: str | None = Field(default=None)
-    preview_id: str | None = Field(default=None)  # references an asset_reference id
 
     @field_validator("hash", mode="before")
     @classmethod
```
```diff
@@ -331,7 +279,7 @@ class UploadAssetSpec(BaseModel):
     @model_validator(mode="after")
     def _validate_order(self):
         if not self.tags:
-            raise ValueError("at least one tag is required for uploads")
+            raise ValueError("tags must be provided and non-empty")
         root = self.tags[0]
         if root not in {"models", "input", "output"}:
             raise ValueError("first tag must be one of: models, input, output")
```
```diff
@@ -4,10 +4,7 @@ from typing import Any
 from pydantic import BaseModel, ConfigDict, Field, field_serializer
 
 
-class Asset(BaseModel):
-    """API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
-    ``id`` here is the AssetReference id, not the content-addressed Asset id."""
-
+class AssetSummary(BaseModel):
     id: str
     name: str
     asset_hash: str | None = None
```
```diff
@@ -15,14 +12,8 @@ class Asset(BaseModel):
     mime_type: str | None = None
     tags: list[str] = Field(default_factory=list)
     preview_url: str | None = None
-    preview_id: str | None = None  # references an asset_reference id, not an asset id
-    user_metadata: dict[str, Any] = Field(default_factory=dict)
-    is_immutable: bool = False
-    metadata: dict[str, Any] | None = None
-    job_id: str | None = None
-    prompt_id: str | None = None  # deprecated: use job_id
-    created_at: datetime
-    updated_at: datetime
+    created_at: datetime | None = None
+    updated_at: datetime | None = None
     last_access_time: datetime | None = None
 
     model_config = ConfigDict(from_attributes=True)
```
```diff
@@ -32,16 +23,50 @@ class Asset(BaseModel):
         return v.isoformat() if v else None
 
 
-class AssetCreated(Asset):
-    created_new: bool
-
-
 class AssetsList(BaseModel):
-    assets: list[Asset]
+    assets: list[AssetSummary]
     total: int
     has_more: bool
 
 
+class AssetUpdated(BaseModel):
+    id: str
+    name: str
+    asset_hash: str | None = None
+    tags: list[str] = Field(default_factory=list)
+    user_metadata: dict[str, Any] = Field(default_factory=dict)
+    updated_at: datetime | None = None
+
+    model_config = ConfigDict(from_attributes=True)
+
+    @field_serializer("updated_at")
+    def _serialize_updated_at(self, v: datetime | None, _info):
+        return v.isoformat() if v else None
+
+
+class AssetDetail(BaseModel):
+    id: str
+    name: str
+    asset_hash: str | None = None
+    size: int | None = None
+    mime_type: str | None = None
+    tags: list[str] = Field(default_factory=list)
+    user_metadata: dict[str, Any] = Field(default_factory=dict)
+    preview_id: str | None = None
+    created_at: datetime | None = None
+    last_access_time: datetime | None = None
+
+    model_config = ConfigDict(from_attributes=True)
+
+    @field_serializer("created_at", "last_access_time")
+    def _serialize_datetime(self, v: datetime | None, _info):
+        return v.isoformat() if v else None
+
+
+class AssetCreated(AssetDetail):
+    created_new: bool
+
+
 class TagUsage(BaseModel):
     name: str
     count: int
```
```diff
@@ -66,7 +91,3 @@ class TagsRemove(BaseModel):
     removed: list[str] = Field(default_factory=list)
     not_present: list[str] = Field(default_factory=list)
     total_tags: list[str] = Field(default_factory=list)
-
-
-class TagHistogram(BaseModel):
-    tag_counts: dict[str, int]
```
```diff
@@ -52,8 +52,6 @@ async def parse_multipart_upload(
     user_metadata_raw: str | None = None
     provided_hash: str | None = None
     provided_hash_exists: bool | None = None
-    provided_mime_type: str | None = None
-    provided_preview_id: str | None = None
 
     file_written = 0
     tmp_path: str | None = None
```
```diff
@@ -130,16 +128,6 @@ async def parse_multipart_upload(
             provided_name = (await field.text()) or None
         elif fname == "user_metadata":
             user_metadata_raw = (await field.text()) or None
-        elif fname == "id":
-            raise UploadError(
-                400,
-                "UNSUPPORTED_FIELD",
-                "Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
-            )
-        elif fname == "mime_type":
-            provided_mime_type = ((await field.text()) or "").strip() or None
-        elif fname == "preview_id":
-            provided_preview_id = ((await field.text()) or "").strip() or None
 
     if not file_present and not (provided_hash and provided_hash_exists):
         raise UploadError(
```
```diff
@@ -164,8 +152,6 @@ async def parse_multipart_upload(
         user_metadata_raw=user_metadata_raw,
         provided_hash=provided_hash,
         provided_hash_exists=provided_hash_exists,
-        provided_mime_type=provided_mime_type,
-        provided_preview_id=provided_preview_id,
     )
 
 
```
```diff
@@ -45,7 +45,13 @@ class Asset(Base):
         passive_deletes=True,
     )
 
-    # preview_id on AssetReference is a self-referential FK to asset_references.id
+    preview_of: Mapped[list[AssetReference]] = relationship(
+        "AssetReference",
+        back_populates="preview_asset",
+        primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
+        foreign_keys=lambda: [AssetReference.preview_id],
+        viewonly=True,
+    )
 
     __table_args__ = (
         Index("uq_assets_hash", "hash", unique=True),
```
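After this change `preview_id` points back at content blobs (`assets.id`) rather than at sibling references, and `Asset` gains a read-only reverse accessor. A hypothetical traversal of the resulting pair (`session` and `ref_id` are assumed to exist):

```python
ref = session.get(AssetReference, ref_id)  # ref_id: hypothetical
thumb = ref.preview_asset                  # Asset | None, via the preview_id FK
if thumb is not None:
    previewed = thumb.preview_of           # list[AssetReference]; viewonly=True,
                                           # so this side is never used to assign previews
```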
```diff
@@ -85,15 +91,11 @@ class AssetReference(Base):
     owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
     name: Mapped[str] = mapped_column(String(512), nullable=False)
     preview_id: Mapped[str | None] = mapped_column(
-        String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
+        String(36), ForeignKey("assets.id", ondelete="SET NULL")
     )
     user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
         JSON(none_as_null=True)
     )
-    system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
-        JSON(none_as_null=True), nullable=True, default=None
-    )
-    job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
     created_at: Mapped[datetime] = mapped_column(
         DateTime(timezone=False), nullable=False, default=get_utc_now
     )
```
```diff
@@ -113,10 +115,10 @@ class AssetReference(Base):
         foreign_keys=[asset_id],
         lazy="selectin",
     )
-    preview_ref: Mapped[AssetReference | None] = relationship(
-        "AssetReference",
+    preview_asset: Mapped[Asset | None] = relationship(
+        "Asset",
+        back_populates="preview_of",
         foreign_keys=[preview_id],
-        remote_side=lambda: [AssetReference.id],
     )
 
     metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
```
```diff
@@ -150,7 +152,6 @@ class AssetReference(Base):
         Index("ix_asset_references_created_at", "created_at"),
         Index("ix_asset_references_last_access_time", "last_access_time"),
         Index("ix_asset_references_deleted_at", "deleted_at"),
-        Index("ix_asset_references_preview_id", "preview_id"),
         Index("ix_asset_references_owner_name", "owner_id", "name"),
         CheckConstraint(
             "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
```
```diff
@@ -191,10 +192,6 @@ class AssetReferenceMeta(Base):
         Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
         Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
         Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
-        CheckConstraint(
-            "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
-            name="has_value",
-        ),
     )
 
 
```
```diff
@@ -31,21 +31,16 @@ from app.assets.database.queries.asset_reference import (
     get_unenriched_references,
     get_unreferenced_unhashed_asset_ids,
     insert_reference,
-    list_all_file_paths_by_asset_id,
     list_references_by_asset_id,
     list_references_page,
     mark_references_missing_outside_prefixes,
-    rebuild_metadata_projection,
-    reference_exists,
     reference_exists_for_asset_id,
     restore_references_by_paths,
     set_reference_metadata,
     set_reference_preview,
-    set_reference_system_metadata,
     soft_delete_reference_by_id,
     update_reference_access_time,
     update_reference_name,
-    update_is_missing_by_asset_id,
     update_reference_timestamps,
     update_reference_updated_at,
     upsert_reference,
```
```diff
@@ -59,7 +54,6 @@ from app.assets.database.queries.tags import (
     bulk_insert_tags_and_meta,
     ensure_tags_exist,
     get_reference_tags,
-    list_tag_counts_for_filtered_assets,
     list_tags_with_usage,
     remove_missing_tag_for_asset_id,
     remove_tags_from_reference,
```
```diff
@@ -103,26 +97,20 @@ __all__ = [
     "get_unenriched_references",
     "get_unreferenced_unhashed_asset_ids",
     "insert_reference",
-    "list_all_file_paths_by_asset_id",
     "list_references_by_asset_id",
     "list_references_page",
-    "list_tag_counts_for_filtered_assets",
     "list_tags_with_usage",
     "mark_references_missing_outside_prefixes",
     "reassign_asset_references",
-    "rebuild_metadata_projection",
-    "reference_exists",
     "reference_exists_for_asset_id",
     "remove_missing_tag_for_asset_id",
     "remove_tags_from_reference",
     "restore_references_by_paths",
     "set_reference_metadata",
     "set_reference_preview",
-    "set_reference_system_metadata",
     "soft_delete_reference_by_id",
     "set_reference_tags",
     "update_asset_hash_and_mime",
-    "update_is_missing_by_asset_id",
     "update_reference_access_time",
     "update_reference_name",
     "update_reference_timestamps",
```
```diff
@@ -69,7 +69,7 @@ def upsert_asset(
     if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
         asset.size_bytes = int(size_bytes)
         changed = True
-    if mime_type and not asset.mime_type:
+    if mime_type and asset.mime_type != mime_type:
         asset.mime_type = mime_type
         changed = True
     if changed:
```
```diff
@@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
         return False
     if asset_hash is not None:
         asset.hash = asset_hash
-    if mime_type is not None and not asset.mime_type:
+    if mime_type is not None:
         asset.mime_type = mime_type
     return True
 
```
```diff
@@ -10,7 +10,7 @@ from decimal import Decimal
 from typing import NamedTuple, Sequence
 
 import sqlalchemy as sa
-from sqlalchemy import delete, select
+from sqlalchemy import delete, exists, select
 from sqlalchemy.dialects import sqlite
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session, noload
```
```diff
@@ -24,14 +24,12 @@ from app.assets.database.models import (
 )
 from app.assets.database.queries.common import (
     MAX_BIND_PARAMS,
-    apply_metadata_filter,
-    apply_tag_filters,
     build_prefix_like_conditions,
     build_visible_owner_clause,
     calculate_rows_per_statement,
     iter_chunks,
 )
-from app.assets.helpers import escape_sql_like_string, get_utc_now
+from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
 
 
 def _check_is_scalar(v):
```
```diff
@@ -46,6 +44,15 @@ def _check_is_scalar(v):
 
 def _scalar_to_row(key: str, ordinal: int, value) -> dict:
     """Convert a scalar value to a typed projection row."""
+    if value is None:
+        return {
+            "key": key,
+            "ordinal": ordinal,
+            "val_str": None,
+            "val_num": None,
+            "val_bool": None,
+            "val_json": None,
+        }
     if isinstance(value, bool):
         return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
     if isinstance(value, (int, float, Decimal)):
```
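Together with the `convert_metadata_to_rows` change in the next hunk, explicit `None` metadata values (and `None` list elements) now materialize as all-NULL projection rows instead of being skipped, which is also why the `has_value` CHECK constraint is removed above. Illustratively:

```python
convert_metadata_to_rows("rating", None)
# before: []    after: [{"key": "rating", "ordinal": 0, "val_str": None,
#                        "val_num": None, "val_bool": None, "val_json": None}]

convert_metadata_to_rows("sizes", [512, None])
# before: just the val_num=512 row
# after:  that row plus an all-NULL row at ordinal 1, preserving list positions
```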
```diff
@@ -59,19 +66,96 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
 def convert_metadata_to_rows(key: str, value) -> list[dict]:
     """Turn a metadata key/value into typed projection rows."""
     if value is None:
-        return []
+        return [_scalar_to_row(key, 0, None)]
 
     if _check_is_scalar(value):
         return [_scalar_to_row(key, 0, value)]
 
     if isinstance(value, list):
         if all(_check_is_scalar(x) for x in value):
-            return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
-        return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
+            return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
+        return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
 
     return [{"key": key, "ordinal": 0, "val_json": value}]
+
+
+def _apply_tag_filters(
+    stmt: sa.sql.Select,
+    include_tags: Sequence[str] | None = None,
+    exclude_tags: Sequence[str] | None = None,
+) -> sa.sql.Select:
+    """include_tags: every tag must be present; exclude_tags: none may be present."""
+    include_tags = normalize_tags(include_tags)
+    exclude_tags = normalize_tags(exclude_tags)
+
+    if include_tags:
+        for tag_name in include_tags:
+            stmt = stmt.where(
+                exists().where(
+                    (AssetReferenceTag.asset_reference_id == AssetReference.id)
+                    & (AssetReferenceTag.tag_name == tag_name)
+                )
+            )
+
+    if exclude_tags:
+        stmt = stmt.where(
+            ~exists().where(
+                (AssetReferenceTag.asset_reference_id == AssetReference.id)
+                & (AssetReferenceTag.tag_name.in_(exclude_tags))
+            )
+        )
+    return stmt
+
+
+def _apply_metadata_filter(
+    stmt: sa.sql.Select,
+    metadata_filter: dict | None = None,
+) -> sa.sql.Select:
+    """Apply filters using asset_reference_meta projection table."""
+    if not metadata_filter:
+        return stmt
+
+    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
+        return sa.exists().where(
+            AssetReferenceMeta.asset_reference_id == AssetReference.id,
+            AssetReferenceMeta.key == key,
+            *preds,
+        )
+
+    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
+        if value is None:
+            no_row_for_key = sa.not_(
+                sa.exists().where(
+                    AssetReferenceMeta.asset_reference_id == AssetReference.id,
+                    AssetReferenceMeta.key == key,
+                )
+            )
+            null_row = _exists_for_pred(
+                key,
+                AssetReferenceMeta.val_json.is_(None),
+                AssetReferenceMeta.val_str.is_(None),
+                AssetReferenceMeta.val_num.is_(None),
+                AssetReferenceMeta.val_bool.is_(None),
+            )
+            return sa.or_(no_row_for_key, null_row)
+
+        if isinstance(value, bool):
+            return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
+        if isinstance(value, (int, float, Decimal)):
+            num = value if isinstance(value, Decimal) else Decimal(str(value))
+            return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
+        if isinstance(value, str):
+            return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
+        return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
+
+    for k, v in metadata_filter.items():
+        if isinstance(v, list):
+            ors = [_exists_clause_for_value(k, elem) for elem in v]
+            if ors:
+                stmt = stmt.where(sa.or_(*ors))
+        else:
+            stmt = stmt.where(_exists_clause_for_value(k, v))
+    return stmt
 
 
 def get_reference_by_id(
```
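These private helpers replace the shared `apply_tag_filters`/`apply_metadata_filter` imports dropped from `queries.common` above. Note the semantics: each include tag contributes its own correlated `EXISTS` (so all must match), while the exclude tags share one `NOT EXISTS ... IN (...)`; metadata values are dispatched to the typed columns of the projection table. A hypothetical call site:

```python
stmt = select(AssetReference).where(AssetReference.deleted_at.is_(None))
stmt = _apply_tag_filters(stmt, include_tags=["models", "checkpoints"], exclude_tags=["missing"])
stmt = _apply_metadata_filter(stmt, {"base_model": "sdxl", "epoch": 10})
# Matches rows that carry BOTH include tags, NEITHER excluded tag, and have
# meta rows with val_str == "sdxl" and val_num == Decimal("10").
```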
```diff
@@ -128,21 +212,6 @@ def reference_exists_for_asset_id(
     return session.execute(q).first() is not None
 
 
-def reference_exists(
-    session: Session,
-    reference_id: str,
-) -> bool:
-    """Return True if a reference with the given ID exists (not soft-deleted)."""
-    q = (
-        select(sa.literal(True))
-        .select_from(AssetReference)
-        .where(AssetReference.id == reference_id)
-        .where(AssetReference.deleted_at.is_(None))
-        .limit(1)
-    )
-    return session.execute(q).first() is not None
-
-
 def insert_reference(
     session: Session,
     asset_id: str,
```
```diff
@@ -267,8 +336,8 @@ def list_references_page(
         escaped, esc = escape_sql_like_string(name_contains)
         base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
 
-    base = apply_tag_filters(base, include_tags, exclude_tags)
-    base = apply_metadata_filter(base, metadata_filter)
+    base = _apply_tag_filters(base, include_tags, exclude_tags)
+    base = _apply_metadata_filter(base, metadata_filter)
 
     sort = (sort or "created_at").lower()
     order = (order or "desc").lower()
```
```diff
@@ -297,8 +366,8 @@ def list_references_page(
         count_stmt = count_stmt.where(
             AssetReference.name.ilike(f"%{escaped}%", escape=esc)
         )
-    count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
-    count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
+    count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
+    count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
 
     total = int(session.execute(count_stmt).scalar_one() or 0)
     refs = session.execute(base).unique().scalars().all()
```
```diff
@@ -310,7 +379,7 @@ def list_references_page(
             select(AssetReferenceTag.asset_reference_id, Tag.name)
             .join(Tag, Tag.name == AssetReferenceTag.tag_name)
             .where(AssetReferenceTag.asset_reference_id.in_(id_list))
-            .order_by(AssetReferenceTag.tag_name.asc())
+            .order_by(AssetReferenceTag.added_at)
         )
         for ref_id, tag_name in rows.all():
             tag_map[ref_id].append(tag_name)
```
```diff
@@ -423,42 +492,6 @@ def update_reference_updated_at(
     )
 
 
-def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
-    """Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
-
-    The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
-    override system keys of the same name.
-    """
-    session.execute(
-        delete(AssetReferenceMeta).where(
-            AssetReferenceMeta.asset_reference_id == ref.id
-        )
-    )
-    session.flush()
-
-    merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
-    if not merged:
-        return
-
-    rows: list[AssetReferenceMeta] = []
-    for k, v in merged.items():
-        for r in convert_metadata_to_rows(k, v):
-            rows.append(
-                AssetReferenceMeta(
-                    asset_reference_id=ref.id,
-                    key=r["key"],
-                    ordinal=int(r["ordinal"]),
-                    val_str=r.get("val_str"),
-                    val_num=r.get("val_num"),
-                    val_bool=r.get("val_bool"),
-                    val_json=r.get("val_json"),
-                )
-            )
-    if rows:
-        session.add_all(rows)
-    session.flush()
-
-
 def set_reference_metadata(
     session: Session,
     reference_id: str,
```
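The removed rebuild merged both metadata dicts before projecting, with user keys shadowing system keys of the same name, as its docstring notes. Concretely (values hypothetical):

```python
system_metadata = {"filename": "model.safetensors", "epoch": 3}
user_metadata = {"epoch": 10}
merged = {**system_metadata, **user_metadata}
assert merged == {"filename": "model.safetensors", "epoch": 10}  # user key wins
```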
@@ -472,24 +505,33 @@ def set_reference_metadata(
     ref.updated_at = get_utc_now()
     session.flush()

-    rebuild_metadata_projection(session, ref)
-
-
-def set_reference_system_metadata(
-    session: Session,
-    reference_id: str,
-    system_metadata: dict | None = None,
-) -> None:
-    """Set system_metadata on a reference and rebuild the merged projection."""
-    ref = session.get(AssetReference, reference_id)
-    if not ref:
-        raise ValueError(f"AssetReference {reference_id} not found")
-
-    ref.system_metadata = system_metadata or {}
-    ref.updated_at = get_utc_now()
+    session.execute(
+        delete(AssetReferenceMeta).where(
+            AssetReferenceMeta.asset_reference_id == reference_id
+        )
+    )
     session.flush()

-    rebuild_metadata_projection(session, ref)
+    if not user_metadata:
+        return
+
+    rows: list[AssetReferenceMeta] = []
+    for k, v in user_metadata.items():
+        for r in convert_metadata_to_rows(k, v):
+            rows.append(
+                AssetReferenceMeta(
+                    asset_reference_id=reference_id,
+                    key=r["key"],
+                    ordinal=int(r["ordinal"]),
+                    val_str=r.get("val_str"),
+                    val_num=r.get("val_num"),
+                    val_bool=r.get("val_bool"),
+                    val_json=r.get("val_json"),
+                )
+            )
+    if rows:
+        session.add_all(rows)
+    session.flush()


 def delete_reference_by_id(
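The rewritten body above flattens a metadata dict into typed `(key, ordinal, val_*)` rows via `convert_metadata_to_rows`. A minimal, self-contained sketch of that flattening idea — the helper below is an illustrative stand-in, not the real implementation, and the list fan-out across ordinals is an assumption inferred from the `ordinal` column:

```python
from decimal import Decimal

def rows_for(key, value):
    # Illustrative stand-in for convert_metadata_to_rows: one row per scalar,
    # lists fan out across ordinals; exactly one val_* column is populated.
    values = value if isinstance(value, list) else [value]
    out = []
    for ordinal, v in enumerate(values):
        row = {"key": key, "ordinal": ordinal,
               "val_str": None, "val_num": None, "val_bool": None, "val_json": None}
        if isinstance(v, bool):  # check bool before int: bool is an int subclass
            row["val_bool"] = v
        elif isinstance(v, (int, float, Decimal)):
            row["val_num"] = Decimal(str(v))
        elif isinstance(v, str):
            row["val_str"] = v
        else:
            row["val_json"] = v
        out.append(row)
    return out

print(rows_for("rating", 5))           # one numeric row, ordinal 0
print(rows_for("labels", ["a", "b"]))  # two string rows, ordinals 0 and 1
```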
@@ -529,19 +571,19 @@ def soft_delete_reference_by_id(
 def set_reference_preview(
     session: Session,
     reference_id: str,
-    preview_reference_id: str | None = None,
+    preview_asset_id: str | None = None,
 ) -> None:
     """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
     ref = session.get(AssetReference, reference_id)
     if not ref:
         raise ValueError(f"AssetReference {reference_id} not found")

-    if preview_reference_id is None:
+    if preview_asset_id is None:
         ref.preview_id = None
     else:
-        if not session.get(AssetReference, preview_reference_id):
-            raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
-        ref.preview_id = preview_reference_id
+        if not session.get(Asset, preview_asset_id):
+            raise ValueError(f"Preview Asset {preview_asset_id} not found")
+        ref.preview_id = preview_asset_id

     ref.updated_at = get_utc_now()
     session.flush()
@@ -567,8 +609,6 @@ def list_references_by_asset_id(
         session.execute(
             select(AssetReference)
             .where(AssetReference.asset_id == asset_id)
-            .where(AssetReference.is_missing == False)  # noqa: E712
-            .where(AssetReference.deleted_at.is_(None))
             .order_by(AssetReference.id.asc())
         )
         .scalars()
@@ -576,25 +616,6 @@ def list_references_by_asset_id(
     )


-def list_all_file_paths_by_asset_id(
-    session: Session,
-    asset_id: str,
-) -> list[str]:
-    """Return every file_path for an asset, including soft-deleted/missing refs.
-
-    Used for orphan cleanup where all on-disk files must be removed.
-    """
-    return list(
-        session.execute(
-            select(AssetReference.file_path)
-            .where(AssetReference.asset_id == asset_id)
-            .where(AssetReference.file_path.isnot(None))
-        )
-        .scalars()
-        .all()
-    )
-
-
 def upsert_reference(
     session: Session,
     asset_id: str,
@@ -834,22 +855,6 @@ def bulk_update_is_missing(
     return total


-def update_is_missing_by_asset_id(
-    session: Session, asset_id: str, value: bool
-) -> int:
-    """Set is_missing flag for ALL references belonging to an asset.
-
-    Returns: Number of rows updated
-    """
-    result = session.execute(
-        sa.update(AssetReference)
-        .where(AssetReference.asset_id == asset_id)
-        .where(AssetReference.deleted_at.is_(None))
-        .values(is_missing=value)
-    )
-    return result.rowcount
-
-
 def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
     """Delete references by their IDs.

@@ -1,14 +1,12 @@
 """Shared utilities for database query modules."""

 import os
-from decimal import Decimal
-from typing import Iterable, Sequence
+from typing import Iterable

 import sqlalchemy as sa
-from sqlalchemy import exists

-from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
-from app.assets.helpers import escape_sql_like_string, normalize_tags
+from app.assets.database.models import AssetReference
+from app.assets.helpers import escape_sql_like_string

 MAX_BIND_PARAMS = 800

@@ -54,74 +52,3 @@ def build_prefix_like_conditions(
         escaped, esc = escape_sql_like_string(base)
         conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
     return conds
-
-
-def apply_tag_filters(
-    stmt: sa.sql.Select,
-    include_tags: Sequence[str] | None = None,
-    exclude_tags: Sequence[str] | None = None,
-) -> sa.sql.Select:
-    """include_tags: every tag must be present; exclude_tags: none may be present."""
-    include_tags = normalize_tags(include_tags)
-    exclude_tags = normalize_tags(exclude_tags)
-
-    if include_tags:
-        for tag_name in include_tags:
-            stmt = stmt.where(
-                exists().where(
-                    (AssetReferenceTag.asset_reference_id == AssetReference.id)
-                    & (AssetReferenceTag.tag_name == tag_name)
-                )
-            )
-
-    if exclude_tags:
-        stmt = stmt.where(
-            ~exists().where(
-                (AssetReferenceTag.asset_reference_id == AssetReference.id)
-                & (AssetReferenceTag.tag_name.in_(exclude_tags))
-            )
-        )
-    return stmt
-
-
-def apply_metadata_filter(
-    stmt: sa.sql.Select,
-    metadata_filter: dict | None = None,
-) -> sa.sql.Select:
-    """Apply filters using asset_reference_meta projection table."""
-    if not metadata_filter:
-        return stmt
-
-    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
-        return sa.exists().where(
-            AssetReferenceMeta.asset_reference_id == AssetReference.id,
-            AssetReferenceMeta.key == key,
-            *preds,
-        )
-
-    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
-        if value is None:
-            return sa.not_(
-                sa.exists().where(
-                    AssetReferenceMeta.asset_reference_id == AssetReference.id,
-                    AssetReferenceMeta.key == key,
-                )
-            )
-
-        if isinstance(value, bool):
-            return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
-        if isinstance(value, (int, float, Decimal)):
-            num = value if isinstance(value, Decimal) else Decimal(str(value))
-            return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
-        if isinstance(value, str):
-            return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
-        return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
-
-    for k, v in metadata_filter.items():
-        if isinstance(v, list):
-            ors = [_exists_clause_for_value(k, elem) for elem in v]
-            if ors:
-                stmt = stmt.where(sa.or_(*ors))
-        else:
-            stmt = stmt.where(_exists_clause_for_value(k, v))
-    return stmt
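The removed helpers encode tag filtering as correlated EXISTS subqueries: one per required tag (AND semantics), and a single NOT EXISTS over the excluded set (NOR semantics). A self-contained sketch of the same pattern; the two table definitions are minimal stand-ins for the real models:

```python
import sqlalchemy as sa

meta = sa.MetaData()
refs = sa.Table("asset_references", meta, sa.Column("id", sa.String, primary_key=True))
ref_tags = sa.Table(
    "asset_reference_tags", meta,
    sa.Column("asset_reference_id", sa.String),
    sa.Column("tag_name", sa.String),
)

stmt = sa.select(refs.c.id)
# include_tags: one EXISTS per required tag, so every tag must be present.
for tag in ("checkpoint", "sdxl"):
    stmt = stmt.where(
        sa.exists().where(
            (ref_tags.c.asset_reference_id == refs.c.id)
            & (ref_tags.c.tag_name == tag)
        )
    )
# exclude_tags: a single NOT EXISTS over the whole set, so none may be present.
stmt = stmt.where(
    ~sa.exists().where(
        (ref_tags.c.asset_reference_id == refs.c.id)
        & (ref_tags.c.tag_name.in_(["missing", "broken"]))
    )
)
print(stmt.compile(compile_kwargs={"literal_binds": True}))
```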
@@ -8,15 +8,12 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session

 from app.assets.database.models import (
-    Asset,
     AssetReference,
     AssetReferenceMeta,
     AssetReferenceTag,
     Tag,
 )
 from app.assets.database.queries.common import (
-    apply_metadata_filter,
-    apply_tag_filters,
     build_visible_owner_clause,
     iter_row_chunks,
 )
@@ -75,9 +72,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
         tag_name
         for (tag_name,) in (
             session.execute(
-                select(AssetReferenceTag.tag_name)
-                .where(AssetReferenceTag.asset_reference_id == reference_id)
-                .order_by(AssetReferenceTag.tag_name.asc())
+                select(AssetReferenceTag.tag_name).where(
+                    AssetReferenceTag.asset_reference_id == reference_id
+                )
             )
         ).all()
     ]
@@ -120,7 +117,7 @@ def set_reference_tags(
     )
     session.flush()

-    return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
+    return SetTagsResult(added=to_add, removed=to_remove, total=desired)


 def add_tags_to_reference(
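One behavioural nuance in this hunk: the base branch sorts the returned tag collections, the target branch returns them as-is. Assuming `to_add`/`to_remove` are set differences (as the names suggest), the difference is only output ordering:

```python
desired = {"lora", "sdxl", "character"}
current = {"sdxl", "old-tag"}
to_add, to_remove = desired - current, current - desired

# Sets iterate in arbitrary order; sorting makes the API response deterministic.
print(sorted(to_add), sorted(to_remove), sorted(desired))
```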
@@ -275,12 +272,6 @@ def list_tags_with_usage(
         .select_from(AssetReferenceTag)
         .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
         .where(build_visible_owner_clause(owner_id))
-        .where(
-            sa.or_(
-                AssetReference.is_missing == False,  # noqa: E712
-                AssetReferenceTag.tag_name == "missing",
-            )
-        )
         .where(AssetReference.deleted_at.is_(None))
         .group_by(AssetReferenceTag.tag_name)
         .subquery()
@@ -317,12 +308,6 @@ def list_tags_with_usage(
         select(AssetReferenceTag.tag_name)
         .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
         .where(build_visible_owner_clause(owner_id))
-        .where(
-            sa.or_(
-                AssetReference.is_missing == False,  # noqa: E712
-                AssetReferenceTag.tag_name == "missing",
-            )
-        )
         .where(AssetReference.deleted_at.is_(None))
         .group_by(AssetReferenceTag.tag_name)
     )
@@ -335,53 +320,6 @@ def list_tags_with_usage(
     return rows_norm, int(total or 0)


-def list_tag_counts_for_filtered_assets(
-    session: Session,
-    owner_id: str = "",
-    include_tags: Sequence[str] | None = None,
-    exclude_tags: Sequence[str] | None = None,
-    name_contains: str | None = None,
-    metadata_filter: dict | None = None,
-    limit: int = 100,
-) -> dict[str, int]:
-    """Return tag counts for assets matching the given filters.
-
-    Uses the same filtering logic as list_references_page but returns
-    {tag_name: count} instead of paginated references.
-    """
-    # Build a subquery of matching reference IDs
-    ref_sq = (
-        select(AssetReference.id)
-        .join(Asset, Asset.id == AssetReference.asset_id)
-        .where(build_visible_owner_clause(owner_id))
-        .where(AssetReference.is_missing == False)  # noqa: E712
-        .where(AssetReference.deleted_at.is_(None))
-    )
-
-    if name_contains:
-        escaped, esc = escape_sql_like_string(name_contains)
-        ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
-
-    ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
-    ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
-    ref_sq = ref_sq.subquery()
-
-    # Count tags across those references
-    q = (
-        select(
-            AssetReferenceTag.tag_name,
-            func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
-        )
-        .where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
-        .group_by(AssetReferenceTag.tag_name)
-        .order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
-        .limit(limit)
-    )
-
-    rows = session.execute(q).all()
-    return {tag_name: int(cnt) for tag_name, cnt in rows}
-
-
 def bulk_insert_tags_and_meta(
     session: Session,
     tag_rows: list[dict],
@@ -18,7 +18,7 @@ from app.assets.database.queries import (
     mark_references_missing_outside_prefixes,
     reassign_asset_references,
     remove_missing_tag_for_asset_id,
-    set_reference_system_metadata,
+    set_reference_metadata,
     update_asset_hash_and_mime,
 )
 from app.assets.services.bulk_ingest import (
@@ -490,8 +490,8 @@ def enrich_asset(
             logging.warning("Failed to hash %s: %s", file_path, e)

     if extract_metadata and metadata:
-        system_metadata = metadata.to_user_metadata()
-        set_reference_system_metadata(session, reference_id, system_metadata)
+        user_metadata = metadata.to_user_metadata()
+        set_reference_metadata(session, reference_id, user_metadata)

     if full_hash:
         existing = get_asset_by_hash(session, full_hash)
@@ -16,12 +16,10 @@ from app.assets.database.queries import (
     get_reference_by_id,
     get_reference_with_owner_check,
     list_references_page,
-    list_all_file_paths_by_asset_id,
     list_references_by_asset_id,
     set_reference_metadata,
     set_reference_preview,
     set_reference_tags,
-    update_asset_hash_and_mime,
     update_reference_access_time,
     update_reference_name,
     update_reference_updated_at,
@@ -69,8 +67,6 @@ def update_asset_metadata(
     user_metadata: UserMetadata = None,
     tag_origin: str = "manual",
     owner_id: str = "",
-    mime_type: str | None = None,
-    preview_id: str | None = None,
 ) -> AssetDetailResult:
     with create_session() as session:
         ref = get_reference_with_owner_check(session, reference_id, owner_id)
@@ -107,21 +103,6 @@ def update_asset_metadata(
             )
             touched = True

-        if mime_type is not None:
-            updated = update_asset_hash_and_mime(
-                session, asset_id=ref.asset_id, mime_type=mime_type
-            )
-            if updated:
-                touched = True
-
-        if preview_id is not None:
-            set_reference_preview(
-                session,
-                reference_id=reference_id,
-                preview_reference_id=preview_id,
-            )
-            touched = True
-
         if touched and user_metadata is None:
             update_reference_updated_at(session, reference_id=reference_id)

@@ -178,9 +159,11 @@ def delete_asset_reference(
             session.commit()
             return True

-        # Orphaned asset - gather ALL file paths (including
-        # soft-deleted / missing refs) so their on-disk files get cleaned up.
-        file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
+        # Orphaned asset - delete it and its files
+        refs = list_references_by_asset_id(session, asset_id=asset_id)
+        file_paths = [
+            r.file_path for r in (refs or []) if getattr(r, "file_path", None)
+        ]
         # Also include the just-deleted file path
         if file_path:
             file_paths.append(file_path)
@@ -202,7 +185,7 @@ def delete_asset_reference(

 def set_asset_preview(
     reference_id: str,
-    preview_reference_id: str | None = None,
+    preview_asset_id: str | None = None,
     owner_id: str = "",
 ) -> AssetDetailResult:
     with create_session() as session:
@@ -211,7 +194,7 @@ def set_asset_preview(
         set_reference_preview(
             session,
             reference_id=reference_id,
-            preview_reference_id=preview_reference_id,
+            preview_asset_id=preview_asset_id,
         )

         result = fetch_reference_asset_and_tags(
@@ -280,47 +263,6 @@ def list_assets_page(
     return ListAssetsResult(items=items, total=total)


-def resolve_hash_to_path(
-    asset_hash: str,
-    owner_id: str = "",
-) -> DownloadResolutionResult | None:
-    """Resolve a blake3 hash to an on-disk file path.
-
-    Only references visible to *owner_id* are considered (owner-less
-    references are always visible).
-
-    Returns a DownloadResolutionResult with abs_path, content_type, and
-    download_name, or None if no asset or live path is found.
-    """
-    with create_session() as session:
-        asset = queries_get_asset_by_hash(session, asset_hash)
-        if not asset:
-            return None
-        refs = list_references_by_asset_id(session, asset_id=asset.id)
-        visible = [
-            r for r in refs
-            if r.owner_id == "" or r.owner_id == owner_id
-        ]
-        abs_path = select_best_live_path(visible)
-        if not abs_path:
-            return None
-        display_name = os.path.basename(abs_path)
-        for ref in visible:
-            if ref.file_path == abs_path and ref.name:
-                display_name = ref.name
-                break
-        ctype = (
-            asset.mime_type
-            or mimetypes.guess_type(display_name)[0]
-            or "application/octet-stream"
-        )
-        return DownloadResolutionResult(
-            abs_path=abs_path,
-            content_type=ctype,
-            download_name=display_name,
-        )
-
-
 def resolve_asset_for_download(
     reference_id: str,
     owner_id: str = "",
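The removed resolver's content-type selection is a three-step fallback chain. In isolation it is just this (a sketch; the function name is a stand-in):

```python
import mimetypes

def pick_content_type(stored_mime, display_name):
    # Stored MIME type wins; otherwise guess from the name; otherwise generic.
    return (
        stored_mime
        or mimetypes.guess_type(display_name)[0]
        or "application/octet-stream"
    )

print(pick_content_type(None, "preview.png"))        # image/png
print(pick_content_type(None, "model.safetensors"))  # application/octet-stream
```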
@@ -11,14 +11,13 @@ from app.assets.database.queries import (
     add_tags_to_reference,
     fetch_reference_and_asset,
     get_asset_by_hash,
+    get_existing_asset_ids,
     get_reference_by_file_path,
     get_reference_tags,
     get_or_create_reference,
-    reference_exists,
     remove_missing_tag_for_asset_id,
     set_reference_metadata,
     set_reference_tags,
-    update_asset_hash_and_mime,
     upsert_asset,
     upsert_reference,
     validate_tags_exist,
@@ -27,7 +26,6 @@ from app.assets.helpers import normalize_tags
 from app.assets.services.file_utils import get_size_and_mtime_ns
 from app.assets.services.path_utils import (
     compute_relative_filename,
-    get_name_and_tags_from_asset_path,
     resolve_destination_from_tags,
     validate_path_within_base,
 )
@@ -67,7 +65,7 @@ def _ingest_file_from_path(

     with create_session() as session:
         if preview_id:
-            if not reference_exists(session, preview_id):
+            if preview_id not in get_existing_asset_ids(session, [preview_id]):
                 preview_id = None

         asset, asset_created, asset_updated = upsert_asset(
@@ -137,8 +135,6 @@ def _register_existing_asset(
     tags: list[str] | None = None,
     tag_origin: str = "manual",
     owner_id: str = "",
-    mime_type: str | None = None,
-    preview_id: str | None = None,
 ) -> RegisterAssetResult:
     user_metadata = user_metadata or {}

@@ -147,25 +143,14 @@ def _register_existing_asset(
     if not asset:
         raise ValueError(f"No asset with hash {asset_hash}")

-    if mime_type and not asset.mime_type:
-        update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
-
-    if preview_id:
-        if not reference_exists(session, preview_id):
-            preview_id = None
-
     ref, ref_created = get_or_create_reference(
         session,
         asset_id=asset.id,
         owner_id=owner_id,
         name=name,
-        preview_id=preview_id,
     )

     if not ref_created:
-        if preview_id and ref.preview_id != preview_id:
-            ref.preview_id = preview_id
-
     tag_names = get_reference_tags(session, reference_id=ref.id)
     result = RegisterAssetResult(
         ref=extract_reference_data(ref),
@@ -257,8 +242,6 @@ def upload_from_temp_path(
     client_filename: str | None = None,
     owner_id: str = "",
     expected_hash: str | None = None,
-    mime_type: str | None = None,
-    preview_id: str | None = None,
 ) -> UploadResult:
     try:
         digest, _ = hashing.compute_blake3_hash(temp_path)
@@ -287,8 +270,6 @@ def upload_from_temp_path(
             tags=tags or [],
             tag_origin="manual",
             owner_id=owner_id,
-            mime_type=mime_type,
-            preview_id=preview_id,
         )
         return UploadResult(
             ref=result.ref,
@@ -310,7 +291,7 @@ def upload_from_temp_path(
     dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
     validate_path_within_base(dest_abs, base_dir)

-    content_type = mime_type or (
+    content_type = (
         mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
         or mimetypes.guess_type(hashed_basename, strict=False)[0]
         or "application/octet-stream"
@@ -334,7 +315,7 @@ def upload_from_temp_path(
         mime_type=content_type,
         info_name=_sanitize_filename(name or client_filename, fallback=digest),
         owner_id=owner_id,
-        preview_id=preview_id,
+        preview_id=None,
         user_metadata=user_metadata or {},
         tags=tags,
         tag_origin="manual",
@@ -361,99 +342,30 @@ def upload_from_temp_path(
     )


-def register_file_in_place(
-    abs_path: str,
-    name: str,
-    tags: list[str],
-    owner_id: str = "",
-    mime_type: str | None = None,
-) -> UploadResult:
-    """Register an already-saved file in the asset database without moving it.
-
-    Tags are derived from the filesystem path (root category + subfolder names),
-    merged with any caller-provided tags, matching the behavior of the scanner.
-    If the path is not under a known root, only the caller-provided tags are used.
-    """
-    try:
-        _, path_tags = get_name_and_tags_from_asset_path(abs_path)
-    except ValueError:
-        path_tags = []
-    merged_tags = normalize_tags([*path_tags, *tags])
-
-    try:
-        digest, _ = hashing.compute_blake3_hash(abs_path)
-    except ImportError as e:
-        raise DependencyMissingError(str(e))
-    except Exception as e:
-        raise RuntimeError(f"failed to hash file: {e}")
-    asset_hash = "blake3:" + digest
-
-    size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
-    content_type = mime_type or (
-        mimetypes.guess_type(abs_path, strict=False)[0]
-        or "application/octet-stream"
-    )
-
-    ingest_result = _ingest_file_from_path(
-        abs_path=abs_path,
-        asset_hash=asset_hash,
-        size_bytes=size_bytes,
-        mtime_ns=mtime_ns,
-        mime_type=content_type,
-        info_name=_sanitize_filename(name, fallback=digest),
-        owner_id=owner_id,
-        tags=merged_tags,
-        tag_origin="upload",
-        require_existing_tags=False,
-    )
-    reference_id = ingest_result.reference_id
-    if not reference_id:
-        raise RuntimeError("failed to create asset reference")
-
-    with create_session() as session:
-        pair = fetch_reference_and_asset(
-            session, reference_id=reference_id, owner_id=owner_id
-        )
-        if not pair:
-            raise RuntimeError("inconsistent DB state after ingest")
-        ref, asset = pair
-        tag_names = get_reference_tags(session, reference_id=ref.id)
-
-    return UploadResult(
-        ref=extract_reference_data(ref),
-        asset=extract_asset_data(asset),
-        tags=tag_names,
-        created_new=ingest_result.asset_created,
-    )
-
-
 def create_from_hash(
     hash_str: str,
     name: str,
     tags: list[str] | None = None,
     user_metadata: dict | None = None,
     owner_id: str = "",
-    mime_type: str | None = None,
-    preview_id: str | None = None,
 ) -> UploadResult | None:
     canonical = hash_str.strip().lower()

-    try:
-        result = _register_existing_asset(
-            asset_hash=canonical,
-            name=_sanitize_filename(
-                name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
-            ),
-            user_metadata=user_metadata or {},
-            tags=tags or [],
-            tag_origin="manual",
-            owner_id=owner_id,
-            mime_type=mime_type,
-            preview_id=preview_id,
-        )
-    except ValueError:
-        logging.warning("create_from_hash: no asset found for hash %s", canonical)
-        return None
+    with create_session() as session:
+        asset = get_asset_by_hash(session, asset_hash=canonical)
+        if not asset:
+            return None
+
+    result = _register_existing_asset(
+        asset_hash=canonical,
+        name=_sanitize_filename(
+            name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
+        ),
+        user_metadata=user_metadata or {},
+        tags=tags or [],
+        tag_origin="manual",
+        owner_id=owner_id,
+    )

     return UploadResult(
         ref=result.ref,
@@ -25,9 +25,7 @@ class ReferenceData:
     preview_id: str | None
     created_at: datetime
     updated_at: datetime
-    system_metadata: dict[str, Any] | None = None
-    job_id: str | None = None
-    last_access_time: datetime | None = None
+    last_access_time: datetime | None


 @dataclass(frozen=True)
@@ -95,8 +93,6 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
         file_path=ref.file_path,
         user_metadata=ref.user_metadata,
         preview_id=ref.preview_id,
-        system_metadata=ref.system_metadata,
-        job_id=ref.job_id,
         created_at=ref.created_at,
         updated_at=ref.updated_at,
         last_access_time=ref.last_access_time,
@@ -1,5 +1,3 @@
-from typing import Sequence
-
 from app.assets.database.queries import (
     AddTagsResult,
     RemoveTagsResult,
@@ -8,7 +6,6 @@ from app.assets.database.queries import (
     list_tags_with_usage,
     remove_tags_from_reference,
 )
-from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
 from app.assets.services.schemas import TagUsage
 from app.database.db import create_session

@@ -76,23 +73,3 @@ def list_tags(
     )

     return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
-
-
-def list_tag_histogram(
-    owner_id: str = "",
-    include_tags: Sequence[str] | None = None,
-    exclude_tags: Sequence[str] | None = None,
-    name_contains: str | None = None,
-    metadata_filter: dict | None = None,
-    limit: int = 100,
-) -> dict[str, int]:
-    with create_session() as session:
-        return list_tag_counts_for_filtered_assets(
-            session,
-            owner_id=owner_id,
-            include_tags=include_tags,
-            exclude_tags=exclude_tags,
-            name_contains=name_contains,
-            metadata_filter=metadata_filter,
-            limit=limit,
-        )
@@ -1,18 +1,9 @@
 from typing import Any
 from datetime import datetime
-from sqlalchemy import MetaData
 from sqlalchemy.orm import DeclarativeBase

-NAMING_CONVENTION = {
-    "ix": "ix_%(table_name)s_%(column_0_N_name)s",
-    "uq": "uq_%(table_name)s_%(column_0_N_name)s",
-    "ck": "ck_%(table_name)s_%(constraint_name)s",
-    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
-    "pk": "pk_%(table_name)s",
-}
-
 class Base(DeclarativeBase):
-    metadata = MetaData(naming_convention=NAMING_CONVENTION)
+    pass

 def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
     fields = obj.__table__.columns.keys()
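For context on what the `NAMING_CONVENTION` removal changes: with a convention on `MetaData`, every anonymous constraint gets a deterministic name, which is what lets Alembic's batch mode locate and drop constraints on SQLite. A runnable sketch (the table here is illustrative):

```python
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite
from sqlalchemy.schema import CreateTable

NAMING_CONVENTION = {
    "ix": "ix_%(table_name)s_%(column_0_N_name)s",
    "uq": "uq_%(table_name)s_%(column_0_N_name)s",
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}

meta = sa.MetaData(naming_convention=NAMING_CONVENTION)
t = sa.Table(
    "assets", meta,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("hash", sa.String, unique=True),
)
# Emits CONSTRAINT pk_assets PRIMARY KEY (id) and
# CONSTRAINT uq_assets_hash UNIQUE (hash) instead of unnamed constraints.
print(CreateTable(t).compile(dialect=sqlite.dialect()))
```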
@@ -83,8 +83,6 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
 fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
 fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")

-parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
-
 parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")

 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
@@ -209,39 +209,3 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
             output_block[i:i + slice_size].copy_(block)

     return output_fp4, to_blocked(output_block, flatten=False)
-
-
-def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
-    def roundup(x_val, multiple):
-        return ((x_val + multiple - 1) // multiple) * multiple
-
-    if pad_32x:
-        rows, cols = x.shape
-        padded_rows = roundup(rows, 32)
-        padded_cols = roundup(cols, 32)
-        if padded_rows != rows or padded_cols != cols:
-            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
-
-    F8_E4M3_MAX = 448.0
-    E8M0_BIAS = 127
-    BLOCK_SIZE = 32
-
-    rows, cols = x.shape
-    x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
-    max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
-
-    # E8M0 block scales (power-of-2 exponents)
-    scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
-    exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
-    block_scales_e8m0 = exp_biased.to(torch.uint8)
-
-    zero_mask = (max_abs == 0)
-    block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
-    block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
-
-    # Scale per-block then stochastic round
-    data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
-    output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
-
-    block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
-    return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
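A worked example of the E8M0 scale arithmetic from the removed mxfp8 kernel: each 32-element block gets a power-of-two scale chosen so its largest value fits in float8_e4m3's ±448 range, and the biased exponent round-trips through the float32 bit layout. (Standalone sketch; zero blocks are handled separately in the real code.)

```python
import torch

F8_E4M3_MAX = 448.0
E8M0_BIAS = 127

max_abs = torch.tensor([300.0, 1.5])  # per-block maxima (example values)
scale_needed = torch.clamp(max_abs / F8_E4M3_MAX, min=2 ** (-127))
exp_biased = torch.clamp(
    torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254
)
# Placing the biased exponent in the float32 exponent field yields the scale.
scales = (exp_biased << 23).view(torch.float32)
print(exp_biased.tolist())  # [127, 119]
print(scales.tolist())      # [1.0, 0.00390625], i.e. 2**0 and 2**-8
```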
@@ -11,7 +11,6 @@ from .causal_conv3d import CausalConv3d
 from .pixel_norm import PixelNorm
 from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
 import comfy.ops
-import comfy.model_management
 from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed

 ops = comfy.ops.disable_weight_init
@@ -537,7 +536,7 @@ class Decoder(nn.Module):
                 mark_conv3d_ended(self.conv_out)
                 sample = self.conv_out(sample, causal=self.causal)
                 if sample is not None and sample.shape[2] > 0:
-                    output.append(sample.to(comfy.model_management.intermediate_device()))
+                    output.append(sample)
                 return

             up_block = self.up_blocks[idx]
@@ -1,68 +1,9 @@
 import math
-import ctypes
-import threading
-import dataclasses
 import torch
 from typing import NamedTuple

 from comfy.quant_ops import QuantizedTensor


-class TensorFileSlice(NamedTuple):
-    file_ref: object
-    thread_id: int
-    offset: int
-    size: int
-
-
-def read_tensor_file_slice_into(tensor, destination):
-
-    if isinstance(tensor, QuantizedTensor):
-        if not isinstance(destination, QuantizedTensor):
-            return False
-        if tensor._layout_cls != destination._layout_cls:
-            return False
-
-        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
-            return False
-
-        dst_orig_dtype = destination._params.orig_dtype
-        destination._params.copy_from(tensor._params, non_blocking=False)
-        destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
-        return True
-
-    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
-    if info is None:
-        return False
-
-    file_obj = info.file_ref
-    if (destination.device.type != "cpu"
-            or file_obj is None
-            or threading.get_ident() != info.thread_id
-            or destination.numel() * destination.element_size() < info.size):
-        return False
-
-    if info.size == 0:
-        return True
-
-    buf_type = ctypes.c_ubyte * info.size
-    view = memoryview(buf_type.from_address(destination.data_ptr()))
-
-    try:
-        file_obj.seek(info.offset)
-        done = 0
-        while done < info.size:
-            try:
-                n = file_obj.readinto(view[done:])
-            except OSError:
-                return False
-            if n <= 0:
-                return False
-            done += n
-        return True
-    finally:
-        view.release()
-
 class TensorGeometry(NamedTuple):
     shape: any
     dtype: torch.dtype
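The removed fast path boils down to reading a known byte range of a file directly into an existing buffer. Stripped of the tensor/ctypes plumbing, the core loop is (self-contained sketch):

```python
import io

def read_slice_into(f, offset, size, buf):
    # Read exactly `size` bytes at `offset` into `buf`; a short read means the
    # caller should fall back to the normal copy path.
    f.seek(offset)
    view = memoryview(buf)[:size]
    done = 0
    while done < size:
        n = f.readinto(view[done:])
        if n <= 0:
            return False
        done += n
    return True

f = io.BytesIO(b"0123456789abcdef")
buf = bytearray(4)
print(read_slice_into(f, 10, 4, buf), bytes(buf))  # True b'abcd'
```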
@@ -400,7 +400,7 @@ try:
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
                 if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
+                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
                         ENABLE_PYTORCH_ATTENTION = True
             if rocm_version >= (7, 0):
                 if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@@ -505,28 +505,6 @@ def module_size(module):
         module_mem += t.nbytes
     return module_mem

-def module_mmap_residency(module, free=False):
-    mmap_touched_mem = 0
-    module_mem = 0
-    bounced_mmaps = set()
-    sd = module.state_dict()
-    for k in sd:
-        t = sd[k]
-        module_mem += t.nbytes
-        storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
-        if not getattr(storage, "_comfy_tensor_mmap_touched", False):
-            continue
-        mmap_touched_mem += t.nbytes
-        if not free:
-            continue
-        storage._comfy_tensor_mmap_touched = False
-        mmap_obj = storage._comfy_tensor_mmap_refs[0]
-        if mmap_obj in bounced_mmaps:
-            continue
-        mmap_obj.bounce()
-        bounced_mmaps.add(mmap_obj)
-    return mmap_touched_mem, module_mem
-
 class LoadedModel:
     def __init__(self, model):
         self._set_model(model)
@@ -554,9 +532,6 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()

-    def model_mmap_residency(self, free=False):
-        return self.model.model_mmap_residency(free=free)
-
     def model_loaded_memory(self):
         return self.model.loaded_size()

@@ -658,7 +633,7 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

-def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
+def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
@@ -671,14 +646,13 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
             can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
             shift_model.currently_used = False

-    can_unload_sorted = sorted(can_unload)
-    for x in can_unload_sorted:
+    for x in sorted(can_unload):
         i = x[-1]
         memory_to_free = 1e32
-        pins_to_free = 1e32
+        ram_to_free = 1e32
         if not DISABLE_SMART_MEMORY:
             memory_to_free = memory_required - get_free_memory(device)
-            pins_to_free = pins_required - get_free_ram()
+            ram_to_free = ram_required - get_free_ram()
         if current_loaded_models[i].model.is_dynamic() and for_dynamic:
             #don't actually unload dynamic models for the sake of other dynamic models
             #as that works on-demand.
@@ -687,18 +661,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
         if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
             logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
             unloaded_model.append(i)
-        if pins_to_free > 0:
-            logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-            current_loaded_models[i].model.partially_unload_ram(pins_to_free)
-
-    for x in can_unload_sorted:
-        i = x[-1]
-        ram_to_free = ram_required - psutil.virtual_memory().available
-        if ram_to_free <= 0 and i not in unloaded_model:
-            continue
-        resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
-        if resident_memory > 0:
+        if ram_to_free > 0:
             logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+            current_loaded_models[i].model.partially_unload_ram(ram_to_free)

     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))
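The sort key in both versions of `free_memory` is worth spelling out: candidates are visited most-offloaded first, with refcount and resident size as tie-breakers. A tiny demonstration with mock numbers:

```python
# (offloaded_bytes, refcount, size_bytes, index) for three mock models
models = [
    (0,         3, 4_000_000_000, 0),
    (2_000_000, 2, 1_000_000_000, 1),
    (500_000,   2, 2_000_000_000, 2),
]
# Same key shape as can_unload: negating offloaded memory sorts the
# most-offloaded model to the front of the unload order.
order = sorted((-off, ref, size, i) for off, ref, size, i in models)
print([x[-1] for x in order])  # [1, 2, 0]
```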
@@ -764,27 +729,17 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

     total_memory_required = {}
-    total_pins_required = {}
     total_ram_required = {}
     for loaded_model in models_to_load:
-        device = loaded_model.device
-        total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
-        resident_memory, model_memory = loaded_model.model_mmap_residency()
-        pinned_memory = loaded_model.model.pinned_memory_size()
-        #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
-        #make this JIT to keep as much pinned as possible.
-        pins_required = model_memory - pinned_memory
-        ram_required = model_memory - resident_memory
-        total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
-        total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
+        #want to do.
+        #FIXME: This should subtract off the to_load current pin consumption.
+        total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2

     for device in total_memory_required:
         if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.1 + extra_mem,
-                        device,
-                        for_dynamic=free_for_dynamic,
-                        pins_required=total_pins_required[device],
-                        ram_required=total_ram_required[device])
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])

     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -1050,12 +1005,6 @@ def intermediate_device():
     else:
         return torch.device("cpu")

-def intermediate_dtype():
-    if args.fp16_intermediates:
-        return torch.float16
-    else:
-        return torch.float32
-
 def vae_device():
     if args.cpu_vae:
         return torch.device("cpu")
@@ -1276,11 +1225,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
         dest_view = dest_views.pop(0)
         if tensor is None:
             continue
-        if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
-            continue
-        storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
-        if hasattr(storage, "_comfy_tensor_mmap_touched"):
-            storage._comfy_tensor_mmap_touched = True
         dest_view.copy_(tensor, non_blocking=non_blocking)

@@ -1718,19 +1662,6 @@ def supports_nvfp4_compute(device=None):

     return True

-def supports_mxfp8_compute(device=None):
-    if not is_nvidia():
-        return False
-
-    if torch_version_numeric < (2, 10):
-        return False
-
-    props = torch.cuda.get_device_properties(device)
-    if props.major < 10:
-        return False
-
-    return True
-
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
     if torch_version_numeric < (2, 7):
@@ -297,9 +297,6 @@ class ModelPatcher:
             self.size = comfy.model_management.module_size(self.model)
         return self.size

-    def model_mmap_residency(self, free=False):
-        return comfy.model_management.module_mmap_residency(self.model, free=free)
-
     def get_ram_usage(self):
         return self.model_size()

@@ -1066,10 +1063,6 @@ class ModelPatcher:

         return self.model.model_loaded_weight_memory - current_used

-    def pinned_memory_size(self):
-        # Pinned memory pressure tracking is only implemented for DynamicVram loading
-        return 0
-
     def partially_unload_ram(self, ram_to_unload):
         pass

@@ -1660,16 +1653,6 @@ class ModelPatcherDynamic(ModelPatcher):

         return freed

-    def pinned_memory_size(self):
-        total = 0
-        loading = self._load_list(for_dynamic=True)
-        for x in loading:
-            _, _, _, _, m, _ = x
-            pin = comfy.pinned_memory.get_pin(m)
-            if pin is not None:
-                total += pin.numel() * pin.element_size()
-        return total
-
     def partially_unload_ram(self, ram_to_unload):
         loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
         for x in loading:
comfy/ops.py (135 changes)
@@ -306,40 +306,10 @@ class CastWeightBiasOp:
     bias_function = []

 class disable_weight_init:
-    @staticmethod
-    def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
-                                   missing_keys, unexpected_keys, weight_shape,
-                                   bias_shape=None):
-        assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
-        prefix_len = len(prefix)
-        for k, v in state_dict.items():
-            key = k[prefix_len:]
-            if key == "weight":
-                if not assign_to_params_buffers:
-                    v = v.clone()
-                module.weight = torch.nn.Parameter(v, requires_grad=False)
-            elif bias_shape is not None and key == "bias" and v is not None:
-                if not assign_to_params_buffers:
-                    v = v.clone()
-                module.bias = torch.nn.Parameter(v, requires_grad=False)
-            else:
-                unexpected_keys.append(k)
-
-        if module.weight is None:
-            module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
-            missing_keys.append(prefix + "weight")
-
-        if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
-            module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
-            missing_keys.append(prefix + "bias")
-
     class Linear(torch.nn.Linear, CastWeightBiasOp):

         def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
-            # don't trust subclasses that BYO state dict loader to call us.
-            if (not comfy.model_management.WINDOWS
-                    or not comfy.memory_management.aimdo_enabled
-                    or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
+            if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
                 super().__init__(in_features, out_features, bias, device, dtype)
                 return

@@ -360,21 +330,32 @@ class disable_weight_init:
|
|||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
if (not comfy.model_management.WINDOWS
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
or not comfy.memory_management.aimdo_enabled
|
|
||||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
|
||||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
disable_weight_init._lazy_load_from_state_dict(
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
self,
|
prefix_len = len(prefix)
|
||||||
state_dict,
|
for k,v in state_dict.items():
|
||||||
prefix,
|
if k[prefix_len:] == "weight":
|
||||||
local_metadata,
|
if not assign_to_params_buffers:
|
||||||
missing_keys,
|
v = v.clone()
|
||||||
unexpected_keys,
|
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||||
weight_shape=(self.in_features, self.out_features),
|
elif k[prefix_len:] == "bias" and v is not None:
|
||||||
bias_shape=(self.out_features,),
|
if not assign_to_params_buffers:
|
||||||
)
|
v = v.clone()
|
||||||
|
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
else:
|
||||||
|
unexpected_keys.append(k)
|
||||||
|
|
||||||
|
#Reconcile default construction of the weight if its missing.
|
||||||
|
if self.weight is None:
|
||||||
|
v = torch.zeros(self.in_features, self.out_features)
|
||||||
|
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
missing_keys.append(prefix+"weight")
|
||||||
|
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||||
|
v = torch.zeros(self.out_features,)
|
||||||
|
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
missing_keys.append(prefix+"bias")
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
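The `+` side of this hunk inlines the lazy loader: `_load_from_state_dict` pulls tensors straight out of the state dict and wraps them as parameters instead of letting torch copy into preallocated storage (`assign_to_params_buffers` controls whether the incoming tensor may be adopted without a clone). A minimal self-contained sketch of the same hook, with illustrative names rather than ComfyUI's:

```python
import torch

class LazyLinear(torch.nn.Module):
    # Sketch of the pattern above: take weights straight out of the
    # state dict instead of copying into preallocated parameters.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None
        self.bias = None

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys, error_msgs):
        for k, v in state_dict.items():
            key = k[len(prefix):]
            if key == "weight":
                self.weight = torch.nn.Parameter(v.clone(), requires_grad=False)
            elif key == "bias":
                self.bias = torch.nn.Parameter(v.clone(), requires_grad=False)
            else:
                unexpected_keys.append(k)
        # missing_keys/unexpected_keys feed load_state_dict's strict checks.
        if self.weight is None:
            missing_keys.append(prefix + "weight")

m = LazyLinear(4, 8)
m.load_state_dict({"weight": torch.randn(8, 4), "bias": torch.randn(8)})
print(m.weight.shape)  # torch.Size([8, 4])
```

ComfyUI's version additionally reconciles absent weights with zero-filled defaults, as the `+` block above shows.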
@@ -566,53 +547,6 @@ class disable_weight_init:
             return super().forward(*args, **kwargs)

     class Embedding(torch.nn.Embedding, CastWeightBiasOp):
-        def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
-                     norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
-                     _freeze=False, device=None, dtype=None):
-            # don't trust subclasses that BYO state dict loader to call us.
-            if (not comfy.model_management.WINDOWS
-                    or not comfy.memory_management.aimdo_enabled
-                    or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
-                super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
-                                 norm_type, scale_grad_by_freq, sparse, _weight,
-                                 _freeze, device, dtype)
-                return
-
-            torch.nn.Module.__init__(self)
-            self.num_embeddings = num_embeddings
-            self.embedding_dim = embedding_dim
-            self.padding_idx = padding_idx
-            self.max_norm = max_norm
-            self.norm_type = norm_type
-            self.scale_grad_by_freq = scale_grad_by_freq
-            self.sparse = sparse
-            # Keep shape/dtype visible for module introspection without reserving storage.
-            embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
-            self.weight = torch.nn.Parameter(
-                torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
-                requires_grad=False,
-            )
-            self.bias = None
-            self.weight_comfy_model_dtype = dtype
-
-        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
-                                  strict, missing_keys, unexpected_keys, error_msgs):
-
-            if (not comfy.model_management.WINDOWS
-                    or not comfy.memory_management.aimdo_enabled
-                    or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
-                return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
-                                                     missing_keys, unexpected_keys, error_msgs)
-            disable_weight_init._lazy_load_from_state_dict(
-                self,
-                state_dict,
-                prefix,
-                local_metadata,
-                missing_keys,
-                unexpected_keys,
-                weight_shape=(self.num_embeddings, self.embedding_dim),
-            )
-
         def reset_parameters(self):
             self.bias = None
             return None
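The deleted `Embedding.__init__` defers allocation by putting the weight on PyTorch's meta device, which records shape and dtype without reserving storage. A small sketch of that trick (the 49408x768 shape is just an illustrative CLIP-sized table, not taken from the diff):

```python
import torch

# A meta tensor carries shape/dtype/device metadata but no backing
# storage, so module introspection works while RAM usage stays at zero.
w = torch.nn.Parameter(
    torch.empty((49408, 768), device="meta", dtype=torch.float16),
    requires_grad=False,
)
print(w.shape, w.dtype, w.device)  # torch.Size([49408, 768]) torch.float16 meta
# Materialize later by allocating a real tensor with the same metadata:
real = torch.empty_like(w, device="cpu")
```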
@@ -867,22 +801,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     orig_shape=(self.out_features, self.in_features),
                 )

-            elif self.quant_format == "mxfp8":
-                # MXFP8: E8M0 block scales stored as uint8 in safetensors
-                block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
-                                                     dtype=torch.uint8)
-
-                if block_scale is None:
-                    raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
-
-                block_scale = block_scale.view(torch.float8_e8m0fnu)
-
-                params = layout_cls.Params(
-                    scale=block_scale,
-                    orig_dtype=MixedPrecisionOps._compute_dtype,
-                    orig_shape=(self.out_features, self.in_features),
-                )
-
             elif self.quant_format == "nvfp4":
                 # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
                 tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@@ -1032,15 +950,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
     nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
-    mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)

     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
         logging.info("Using mixed precision operations")
         disabled = set()
         if not nvfp4_compute:
             disabled.add("nvfp4")
-        if not mxfp8_compute:
-            disabled.add("mxfp8")
         if not fp8_compute:
             disabled.add("float8_e4m3fn")
             disabled.add("float8_e5m2")
@@ -1,7 +1,6 @@
+import torch
 import comfy.model_management
 import comfy.memory_management
-import comfy_aimdo.host_buffer
-import comfy_aimdo.torch

 from comfy.cli_args import args

@@ -13,31 +12,18 @@ def pin_memory(module):
         return
     #FIXME: This is a RAM cache trigger event
     size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
-    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
+    pin = torch.empty((size,), dtype=torch.uint8)
+    if comfy.model_management.pin_memory(pin):
+        module._pin = pin
+    else:
         module.pin_failed = True
         return False

-    try:
-        hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
-    except RuntimeError:
-        module.pin_failed = True
-        return False
-
-    module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
-    module._pin_hostbuf = hostbuf
-    comfy.model_management.TOTAL_PINNED_MEMORY += size
     return True

 def unpin_memory(module):
     if get_pin(module) is None:
         return 0
     size = module._pin.numel() * module._pin.element_size()
-    comfy.model_management.TOTAL_PINNED_MEMORY -= size
-    if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
-        comfy.model_management.TOTAL_PINNED_MEMORY = 0
-
+    comfy.model_management.unpin_memory(module._pin)
     del module._pin
-    del module._pin_hostbuf
     return size
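Both sides of this hunk implement the same idea: stage a module's weights in page-locked (pinned) host RAM so host-to-GPU copies can run asynchronously; they differ only in whether comfy_aimdo or a model_management helper does the pinning. A sketch using stock PyTorch only; note that `Tensor.pin_memory()` returns a pinned copy, whereas the diff pins an existing allocation in place, so this is an approximation:

```python
import torch

def pin_staging_buffer(module):
    # Allocate a flat uint8 staging buffer sized for the module's
    # parameters and pin it. Pinned (page-locked) memory lets
    # cudaMemcpyAsync overlap transfers with compute.
    size = sum(p.numel() * p.element_size() for p in module.parameters())
    try:
        pin = torch.empty((size,), dtype=torch.uint8).pin_memory()
    except RuntimeError:  # no accelerator available / pinning failed
        module.pin_failed = True
        return False
    module._pin = pin
    return True

lin = torch.nn.Linear(64, 64)
if pin_staging_buffer(lin):
    print(lin._pin.is_pinned())  # True on a machine with a CUDA context
```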
@@ -43,18 +43,6 @@ except ImportError as e:
     def get_layout_class(name):
         return None

-_CK_MXFP8_AVAILABLE = False
-if _CK_AVAILABLE:
-    try:
-        from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
-        _CK_MXFP8_AVAILABLE = True
-    except ImportError:
-        logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
-
-if not _CK_MXFP8_AVAILABLE:
-    class _CKMxfp8Layout:
-        pass
-
 import comfy.float

 # ==============================================================================
@@ -96,31 +84,6 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         return qdata, params


-class TensorCoreMXFP8Layout(_CKMxfp8Layout):
-    @classmethod
-    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
-        if tensor.dim() != 2:
-            raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
-
-        orig_dtype = tensor.dtype
-        orig_shape = tuple(tensor.shape)
-
-        padded_shape = cls.get_padded_shape(orig_shape)
-        needs_padding = padded_shape != orig_shape
-
-        if stochastic_rounding > 0:
-            qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
-        else:
-            qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
-
-        params = cls.Params(
-            scale=block_scale,
-            orig_dtype=orig_dtype,
-            orig_shape=orig_shape,
-        )
-        return qdata, params
-
-
 class TensorCoreNVFP4Layout(_CKNvfp4Layout):
     @classmethod
     def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@@ -174,8 +137,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
 register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
 register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
 register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
-if _CK_MXFP8_AVAILABLE:
-    register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)

 QUANT_ALGOS = {
     "float8_e4m3fn": {
@@ -196,14 +157,6 @@ QUANT_ALGOS = {
     },
 }

-if _CK_MXFP8_AVAILABLE:
-    QUANT_ALGOS["mxfp8"] = {
-        "storage_t": torch.float8_e4m3fn,
-        "parameters": {"weight_scale", "input_scale"},
-        "comfy_tensor_layout": "TensorCoreMXFP8Layout",
-        "group_size": 32,
-    }
-
 # ==============================================================================
 # Re-exports for backward compatibility
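For readers unfamiliar with the format being removed: MXFP8 stores FP8 (e4m3) values in blocks of 32, each block sharing one power-of-two scale whose exponent is kept in E8M0 form (8 exponent bits, no mantissa). A rough reference-style sketch of that block quantization in plain torch (not the comfy_kitchen kernel; requires a torch build that ships float8 dtypes):

```python
import torch

def quantize_mxfp8_reference(t, block=32):
    # Each contiguous group of `block` values along the last dim shares
    # one power-of-two scale, chosen so the block's max maps near
    # float8_e4m3's max normal value (448).
    rows, cols = t.shape
    assert cols % block == 0, "pad to a multiple of the block size first"
    blocks = t.reshape(rows, cols // block, block)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    exp = torch.ceil(torch.log2(amax / 448.0))   # E8M0: exponent only
    scale = torch.exp2(exp)
    q = (blocks / scale).to(torch.float8_e4m3fn)  # torch >= 2.1
    return q.reshape(rows, cols), scale.squeeze(-1)

w = torch.randn(4, 64)
q, scales = quantize_mxfp8_reference(w)
deq = q.to(torch.float32).reshape(4, 2, 32) * scales.unsqueeze(-1)
print((w - deq.reshape(4, 64)).abs().max())  # small quantization error
```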
27
comfy/sd.py
@@ -871,16 +871,13 @@ class VAE:
             pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels

-    def vae_output_dtype(self):
-        return model_management.intermediate_dtype()
-
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = comfy.utils.ProgressBar(steps)

-        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         output = self.process_output(
             (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
             comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@@ -890,16 +887,16 @@ class VAE:

     def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
         if samples.ndim == 3:
-            decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+            decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         else:
             og_shape = samples.shape
             samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
-            decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+            decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()

         return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))

     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
-        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))

     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -908,7 +905,7 @@ class VAE:
         steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = comfy.utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
         samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -917,7 +914,7 @@ class VAE:

     def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
         if self.latent_dim == 1:
-            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
             out_channels = self.latent_channels
             upscale_amount = 1 / self.downscale_ratio
         else:
@@ -926,7 +923,7 @@ class VAE:
             tile_x = tile_x // extra_channel_size
             overlap = overlap // extra_channel_size
             upscale_amount = 1 / self.downscale_ratio
-            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
+            encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()

         out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
         if self.latent_dim == 1:
@@ -935,7 +932,7 @@ class VAE:
             return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)

     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
-        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
+        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
         return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

     def decode(self, samples_in, vae_options={}):
@@ -953,9 +950,9 @@ class VAE:

             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
+                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                 if pixel_samples is None:
-                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
         except Exception as e:
             model_management.raise_non_oom(e)
@@ -1028,9 +1025,9 @@ class VAE:
             samples = None
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
-                out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
+                out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
                 if samples is None:
-                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 samples[x:x + batch_number] = out

         except Exception as e:
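All the `decode_tiled_*`/`encode_tiled_*` variants above share one mechanism: run the VAE over overlapping tiles and feather the seams so adjacent tiles blend instead of showing hard edges. A 1D sketch of the feathering weight (illustrative, not `comfy.utils.tiled_scale` itself):

```python
import torch

def tile_blend_mask(tile, overlap):
    # Per-tile weight that ramps from ~0 to 1 across the overlap region.
    w = torch.ones(tile)
    ramp = torch.linspace(0, 1, overlap + 2)[1:-1]
    w[:overlap] = ramp
    w[-overlap:] = ramp.flip(0)
    return w

m = tile_blend_mask(64, 16)
# The tail of one tile's weight plus the head of the next sums to 1,
# so the overlapped outputs form a proper weighted average:
print(m[-16:] + m[:16])  # all ones
```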
@@ -20,8 +20,6 @@
 import torch
 import math
 import struct
-import ctypes
-import os
 import comfy.memory_management
 import safetensors.torch
 import numpy as np
@@ -34,7 +32,7 @@ from einops import rearrange
 from comfy.cli_args import args
 import json
 import time
-import threading
+import mmap
 import warnings

 MMAP_TORCH_FILES = args.mmap_torch_files
@@ -83,17 +81,14 @@ _TYPES = {
 }

 def load_safetensors(ckpt):
-    import comfy_aimdo.model_mmap
-
-    f = open(ckpt, "rb", buffering=0)
-    model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
-    file_size = os.path.getsize(ckpt)
-    mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
-
-    header_size = struct.unpack("<Q", mv[:8])[0]
-    header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
-
-    mv = mv[(data_base_offset := 8 + header_size):]
+    f = open(ckpt, "rb")
+    mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
+    mv = memoryview(mapping)
+
+    header_size = struct.unpack("<Q", mapping[:8])[0]
+    header = json.loads(mapping[8:8+header_size].decode("utf-8"))
+
+    mv = mv[8 + header_size:]

     sd = {}
     for name, info in header.items():
@@ -107,14 +102,7 @@ def load_safetensors(ckpt):
         with warnings.catch_warnings():
             #We are working with read-only RAM by design
             warnings.filterwarnings("ignore", message="The given buffer is not writable")
-            tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
-            storage = tensor.untyped_storage()
-            setattr(storage,
-                    "_comfy_tensor_file_slice",
-                    comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
-            setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
-            setattr(storage, "_comfy_tensor_mmap_touched", False)
-            sd[name] = tensor
+            sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])

     return sd, header.get("__metadata__", {}),
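Both versions of `load_safetensors` parse the same container layout: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/byte offsets, then raw tensor data starting at `8 + header_size`. A standalone reader sketch against a file written by the safetensors library:

```python
import json, mmap, os, struct, tempfile
import torch, safetensors.torch

def read_header(path):
    # 8-byte LE header size, then a JSON header; tensor bytes follow.
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
        header_size = struct.unpack("<Q", m[:8])[0]
        return json.loads(m[8:8 + header_size].decode("utf-8")), 8 + header_size

path = os.path.join(tempfile.gettempdir(), "demo.safetensors")
safetensors.torch.save_file({"w": torch.ones(2, 2)}, path)
header, data_base = read_header(path)
print(header["w"]["dtype"], header["w"]["shape"], header["w"]["data_offsets"])
# e.g. F32 [2, 2] [0, 16] -- data_offsets are relative to data_base
```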
@@ -1459,7 +1459,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
             node_id="KlingOmniProEditVideoNode",
             display_name="Kling 3.0 Omni Edit Video",
             category="api node/video/Kling",
-            essentials_category="Video Generation",
             description="Edit an existing video with the latest model from Kling.",
             inputs=[
                 IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@@ -833,7 +833,6 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
             node_id="RecraftVectorizeImageNode",
             display_name="Recraft Vectorize Image",
             category="api node/image/Recraft",
-            essentials_category="Image Tools",
             description="Generates SVG synchronously from an input image.",
             inputs=[
                 IO.Image.Input("image"),
@@ -19,7 +19,6 @@ class EmptyLatentAudio(IO.ComfyNode):
             node_id="EmptyLatentAudio",
             display_name="Empty Latent Audio",
             category="latent/audio",
-            essentials_category="Audio",
             inputs=[
                 IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
                 IO.Int.Input(
@@ -186,7 +185,6 @@ class SaveAudioMP3(IO.ComfyNode):
             search_aliases=["export mp3"],
             display_name="Save Audio (MP3)",
             category="audio",
-            essentials_category="Audio",
             inputs=[
                 IO.Audio.Input("audio"),
                 IO.String.Input("filename_prefix", default="audio/ComfyUI"),
@@ -14,7 +14,6 @@ class ImageCompare(IO.ComfyNode):
             display_name="Image Compare",
             description="Compares two images side by side with a slider.",
             category="image",
-            essentials_category="Image Tools",
             is_experimental=True,
             is_output_node=True,
             inputs=[
@@ -58,7 +58,6 @@ class ImageCropV2(IO.ComfyNode):
             search_aliases=["trim"],
             display_name="Image Crop",
             category="image/transform",
-            essentials_category="Image Tools",
             inputs=[
                 IO.Image.Input("image"),
                 IO.BoundingBox.Input("crop_region", component="ImageCrop"),
@@ -21,7 +21,6 @@ class Blend(io.ComfyNode):
             node_id="ImageBlend",
             display_name="Image Blend",
             category="image/postprocessing",
-            essentials_category="Image Tools",
             inputs=[
                 io.Image.Input("image1"),
                 io.Image.Input("image2"),
@@ -1 +1 @@
-comfyui_manager==4.1b5
+comfyui_manager==4.1b2
@@ -32,7 +32,7 @@ async def cache_control(
     )

     if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
-        response.headers.setdefault("Cache-Control", "no-store")
+        response.headers.setdefault("Cache-Control", "no-cache")
         return response

     # Early return for non-image files - no cache headers needed
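On the `no-store` vs `no-cache` change in `cache_control`: `no-store` forbids caching entirely, while `no-cache` lets the browser keep a copy but forces revalidation before each use, so a redeployed frontend is still picked up on the next request. A minimal aiohttp middleware in the same shape as the hunk:

```python
from aiohttp import web

@web.middleware
async def cache_control(request, handler):
    # Static JS/CSS gets a conservative policy so frontend updates are
    # seen immediately: cached copies must be revalidated before reuse.
    response = await handler(request)
    if request.path.endswith((".js", ".css")):
        response.headers.setdefault("Cache-Control", "no-cache")
    return response

app = web.Application(middlewares=[cache_control])
```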
14
nodes.py
@@ -81,7 +81,6 @@ class CLIPTextEncode(ComfyNodeABC):


 class ConditioningCombine:
-    ESSENTIALS_CATEGORY = "Image Generation"
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@@ -1212,6 +1211,9 @@ class GLIGENTextBoxApply:
         return (c, )

 class EmptyLatentImage:
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()
+
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -1230,7 +1232,7 @@ class EmptyLatentImage:
     SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]

     def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
         return ({"samples": latent, "downscale_ratio_spacial": 8}, )


@@ -1722,8 +1724,6 @@ class LoadImage:
         output_masks = []
         w, h = None, None

-        dtype = comfy.model_management.intermediate_dtype()
-
         for i in ImageSequence.Iterator(img):
             i = node_helpers.pillow(ImageOps.exif_transpose, i)

@@ -1748,8 +1748,8 @@ class LoadImage:
                 mask = 1. - torch.from_numpy(mask)
             else:
                 mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
-            output_images.append(image.to(dtype=dtype))
-            output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
+            output_images.append(image)
+            output_masks.append(mask.unsqueeze(0))

             if img.format == "MPO":
                 break # ignore all frames except the first one for MPO format
@@ -1779,7 +1779,6 @@ class LoadImage:
         return True

 class LoadImageMask:
-    ESSENTIALS_CATEGORY = "Image Tools"
     SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]

     _color_channels = ["alpha", "red", "green", "blue"]
@@ -1888,7 +1887,6 @@ class ImageScale:
         return (s,)

 class ImageScaleBy:
-    ESSENTIALS_CATEGORY = "Image Tools"
     upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

     @classmethod
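The `EmptyLatentImage.generate` change above concerns the standard SD latent layout: 4 channels at 1/8 of the pixel resolution, so a 512x512 image corresponds to a [1, 4, 64, 64] latent. As a quick standalone check:

```python
import torch

def empty_latent(width, height, batch_size=1, device="cpu"):
    # SD-style latent: 4 channels, spatial dims downscaled 8x.
    return torch.zeros([batch_size, 4, height // 8, width // 8], device=device)

print(empty_latent(512, 512).shape)  # torch.Size([1, 4, 64, 64])
```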
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.41.20
+comfyui-frontend-package==1.41.18
 comfyui-workflow-templates==0.9.21
 comfyui-embedded-docs==0.4.3
 torch
@@ -23,7 +23,7 @@ SQLAlchemy
 filelock
 av>=14.2.0
 comfy-kitchen>=0.2.8
-comfy-aimdo>=0.2.12
+comfy-aimdo>=0.2.10
 requests
 simpleeval>=1.0.0
 blake3
81
server.py
@@ -35,8 +35,6 @@ from app.frontend_management import FrontendManager, parse_version
 from comfy_api.internal import _ComfyNodeInternal
 from app.assets.seeder import asset_seeder
 from app.assets.api.routes import register_assets_routes
-from app.assets.services.ingest import register_file_in_place
-from app.assets.services.asset_management import resolve_hash_to_path

 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
@@ -312,7 +310,7 @@ class PromptServer():
         @routes.get("/")
         async def get_root(request):
            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
-            response.headers['Cache-Control'] = 'no-store, must-revalidate'
+            response.headers['Cache-Control'] = 'no-cache'
            response.headers["Pragma"] = "no-cache"
            response.headers["Expires"] = "0"
            return response
@@ -421,24 +419,7 @@ class PromptServer():
                 with open(filepath, "wb") as f:
                     f.write(image.file.read())

-                resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
-
-                if args.enable_assets:
-                    try:
-                        tag = image_upload_type if image_upload_type in ("input", "output") else "input"
-                        result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
-                        resp["asset"] = {
-                            "id": result.ref.id,
-                            "name": result.ref.name,
-                            "asset_hash": result.asset.hash,
-                            "size": result.asset.size_bytes,
-                            "mime_type": result.asset.mime_type,
-                            "tags": result.tags,
-                        }
-                    except Exception:
-                        logging.warning("Failed to register uploaded image as asset", exc_info=True)
-
-                return web.json_response(resp)
+                return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
             else:
                 return web.Response(status=400)

@@ -498,43 +479,30 @@ class PromptServer():
         async def view_image(request):
             if "filename" in request.rel_url.query:
                 filename = request.rel_url.query["filename"]
+                filename, output_dir = folder_paths.annotated_filepath(filename)

-                # The frontend's LoadImage combo widget uses asset_hash values
-                # (e.g. "blake3:...") as widget values. When litegraph renders the
-                # node preview, it constructs /view?filename=<asset_hash>, so this
-                # endpoint must resolve blake3 hashes to their on-disk file paths.
-                if filename.startswith("blake3:"):
-                    owner_id = self.user_manager.get_request_user_id(request)
-                    result = resolve_hash_to_path(filename, owner_id=owner_id)
-                    if result is None:
-                        return web.Response(status=404)
-                    file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
-                else:
-                    resolved_content_type = None
-                    filename, output_dir = folder_paths.annotated_filepath(filename)
-
-                    if not filename:
-                        return web.Response(status=400)
-
-                    # validation for security: prevent accessing arbitrary path
-                    if filename[0] == '/' or '..' in filename:
-                        return web.Response(status=400)
-
-                    if output_dir is None:
-                        type = request.rel_url.query.get("type", "output")
-                        output_dir = folder_paths.get_directory_by_type(type)
-
-                    if output_dir is None:
-                        return web.Response(status=400)
-
-                    if "subfolder" in request.rel_url.query:
-                        full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
-                        if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
-                            return web.Response(status=403)
-                        output_dir = full_output_dir
-
-                    filename = os.path.basename(filename)
-                    file = os.path.join(output_dir, filename)
+                if not filename:
+                    return web.Response(status=400)
+
+                # validation for security: prevent accessing arbitrary path
+                if filename[0] == '/' or '..' in filename:
+                    return web.Response(status=400)
+
+                if output_dir is None:
+                    type = request.rel_url.query.get("type", "output")
+                    output_dir = folder_paths.get_directory_by_type(type)
+
+                if output_dir is None:
+                    return web.Response(status=400)
+
+                if "subfolder" in request.rel_url.query:
+                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
+                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
+                        return web.Response(status=403)
+                    output_dir = full_output_dir
+
+                filename = os.path.basename(filename)
+                file = os.path.join(output_dir, filename)

                 if os.path.isfile(file):
                     if 'preview' in request.rel_url.query:
@@ -594,13 +562,8 @@ class PromptServer():
                         return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                             headers={"Content-Disposition": f"filename=\"{filename}\""})
                     else:
-                        # Use the content type from asset resolution if available,
-                        # otherwise guess from the filename.
-                        content_type = (
-                            resolved_content_type
-                            or mimetypes.guess_type(filename)[0]
-                            or 'application/octet-stream'
-                        )
+                        # Get content type from mimetype, defaulting to 'application/octet-stream'
+                        content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'

                         # For security, force certain mimetypes to download instead of display
                         if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
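The reshuffled `view_image` validation boils down to three guards: reject empty names, reject absolute paths or `..` components, and confirm with `os.path.commonpath` that the subfolder still resolves under the output directory. A condensed sketch of those guards outside the server (hypothetical `safe_join` helper):

```python
import os

def safe_join(base_dir, subfolder, filename):
    # Mirror the endpoint's checks: no absolute paths, no "..", and the
    # resolved directory must remain under base_dir per commonpath.
    if not filename or filename.startswith("/") or ".." in filename:
        raise PermissionError("bad filename")
    base = os.path.abspath(base_dir)
    full_dir = os.path.abspath(os.path.join(base, subfolder))
    if os.path.commonpath((full_dir, base)) != base:
        raise PermissionError("subfolder escapes base dir")
    return os.path.join(full_dir, os.path.basename(filename))

print(safe_join("/tmp/outputs", "previews", "img.png"))  # /tmp/outputs/previews/img.png
# safe_join("/tmp/outputs", "../../etc", "passwd") -> PermissionError
```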
@@ -1,57 +0,0 @@
-"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
-
-This catches problems like unnamed FK constraints that prevent batch-mode
-drop_constraint from working on real SQLite files (see MB-2).
-
-Migrations 0001 and 0002 are already shipped, so we only exercise
-upgrade/downgrade for 0003+.
-"""
-
-import os
-
-import pytest
-from alembic import command
-from alembic.config import Config
-
-
-# Oldest shipped revision — we upgrade to here as a baseline and never
-# downgrade past it.
-_BASELINE = "0002_merge_to_asset_references"
-
-
-def _make_config(db_path: str) -> Config:
-    root = os.path.join(os.path.dirname(__file__), "../..")
-    config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
-    scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
-
-    cfg = Config(config_path)
-    cfg.set_main_option("script_location", scripts_path)
-    cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
-    return cfg
-
-
-@pytest.fixture
-def migration_db(tmp_path):
-    """Yield an alembic Config pre-upgraded to the baseline revision."""
-    db_path = str(tmp_path / "test_migration.db")
-    cfg = _make_config(db_path)
-    command.upgrade(cfg, _BASELINE)
-    yield cfg
-
-
-def test_upgrade_to_head(migration_db):
-    """Upgrade from baseline to head must succeed on a file-backed DB."""
-    command.upgrade(migration_db, "head")
-
-
-def test_downgrade_to_baseline(migration_db):
-    """Upgrade to head then downgrade back to baseline."""
-    command.upgrade(migration_db, "head")
-    command.downgrade(migration_db, _BASELINE)
-
-
-def test_upgrade_downgrade_cycle(migration_db):
-    """Full cycle: upgrade → downgrade → upgrade again."""
-    command.upgrade(migration_db, "head")
-    command.downgrade(migration_db, _BASELINE)
-    command.upgrade(migration_db, "head")
@@ -10,7 +10,6 @@ from app.assets.database.queries import (
     get_asset_by_hash,
     upsert_asset,
     bulk_insert_assets,
-    update_asset_hash_and_mime,
 )


@@ -143,45 +142,3 @@ class TestBulkInsertAssets:
         session.commit()

         assert session.query(Asset).count() == 200
-
-
-class TestMimeTypeImmutability:
-    """mime_type on Asset is write-once: set on first ingest, never overwritten."""
-
-    @pytest.mark.parametrize(
-        "initial_mime,second_mime,expected_mime",
-        [
-            ("image/png", "image/jpeg", "image/png"),
-            (None, "image/png", "image/png"),
-        ],
-        ids=["preserves_existing", "fills_null"],
-    )
-    def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
-        h = f"blake3:upsert_{initial_mime}_{second_mime}"
-        upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
-        session.commit()
-
-        asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
-        assert created is False
-        assert asset.mime_type == expected_mime
-
-    @pytest.mark.parametrize(
-        "initial_mime,update_mime,update_hash,expected_mime,expected_hash",
-        [
-            (None, "image/png", None, "image/png", "blake3:upd0"),
-            ("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
-            ("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
-        ],
-        ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
-    )
-    def test_update_asset_hash_and_mime_immutability(
-        self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
-    ):
-        h = expected_hash.removesuffix("_new")
-        asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
-        session.add(asset)
-        session.flush()
-
-        update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
-        assert asset.mime_type == expected_mime
-        assert asset.hash == expected_hash
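The deleted tests above pin down a write-once rule for `mime_type`: the stored value is kept once set, and only a NULL is ever filled in. The rule itself, separated from the ORM (hypothetical helper name):

```python
def merge_mime(existing, incoming):
    # Write-once semantics from the deleted tests: keep the stored
    # mime_type if present, otherwise adopt the incoming one.
    return existing if existing is not None else incoming

assert merge_mime("image/png", "image/jpeg") == "image/png"  # preserves_existing
assert merge_mime(None, "image/png") == "image/png"          # fills_null
```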
@@ -242,24 +242,22 @@ class TestSetReferencePreview:
         asset = _make_asset(session, "hash1")
         preview_asset = _make_asset(session, "preview_hash")
         ref = _make_reference(session, asset)
-        preview_ref = _make_reference(session, preview_asset, name="preview.png")
         session.commit()

-        set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
+        set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
         session.commit()

         session.refresh(ref)
-        assert ref.preview_id == preview_ref.id
+        assert ref.preview_id == preview_asset.id

     def test_clears_preview(self, session: Session):
         asset = _make_asset(session, "hash1")
         preview_asset = _make_asset(session, "preview_hash")
         ref = _make_reference(session, asset)
-        preview_ref = _make_reference(session, preview_asset, name="preview.png")
-        ref.preview_id = preview_ref.id
+        ref.preview_id = preview_asset.id
         session.commit()

-        set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
+        set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
         session.commit()

         session.refresh(ref)
@@ -267,15 +265,15 @@ class TestSetReferencePreview:

     def test_raises_for_nonexistent_reference(self, session: Session):
         with pytest.raises(ValueError, match="not found"):
-            set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
+            set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)

     def test_raises_for_nonexistent_preview(self, session: Session):
         asset = _make_asset(session, "hash1")
         ref = _make_reference(session, asset)
         session.commit()

-        with pytest.raises(ValueError, match="Preview AssetReference"):
-            set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
+        with pytest.raises(ValueError, match="Preview Asset"):
+            set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")


 class TestInsertReference:
@@ -353,14 +351,13 @@ class TestUpdateReferenceTimestamps:
         asset = _make_asset(session, "hash1")
         preview_asset = _make_asset(session, "preview_hash")
         ref = _make_reference(session, asset)
-        preview_ref = _make_reference(session, preview_asset, name="preview.png")
         session.commit()

-        update_reference_timestamps(session, ref, preview_id=preview_ref.id)
+        update_reference_timestamps(session, ref, preview_id=preview_asset.id)
         session.commit()

         session.refresh(ref)
-        assert ref.preview_id == preview_ref.id
+        assert ref.preview_id == preview_asset.id


 class TestSetReferenceMetadata:
@@ -20,7 +20,6 @@ def _make_reference(
|
|||||||
asset: Asset,
|
asset: Asset,
|
||||||
name: str,
|
name: str,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
system_metadata: dict | None = None,
|
|
||||||
) -> AssetReference:
|
) -> AssetReference:
|
||||||
now = get_utc_now()
|
now = get_utc_now()
|
||||||
ref = AssetReference(
|
ref = AssetReference(
|
||||||
@@ -28,7 +27,6 @@ def _make_reference(
|
|||||||
name=name,
|
name=name,
|
||||||
asset_id=asset.id,
|
asset_id=asset.id,
|
||||||
user_metadata=metadata,
|
user_metadata=metadata,
|
||||||
system_metadata=system_metadata,
|
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
last_access_time=now,
|
last_access_time=now,
|
||||||
@@ -36,10 +34,8 @@ def _make_reference(
|
|||||||
session.add(ref)
|
session.add(ref)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
# Build merged projection: {**system_metadata, **user_metadata}
|
if metadata:
|
||||||
merged = {**(system_metadata or {}), **(metadata or {})}
|
for key, val in metadata.items():
|
||||||
if merged:
|
|
||||||
for key, val in merged.items():
|
|
||||||
for row in convert_metadata_to_rows(key, val):
|
for row in convert_metadata_to_rows(key, val):
|
||||||
meta_row = AssetReferenceMeta(
|
meta_row = AssetReferenceMeta(
|
||||||
asset_reference_id=ref.id,
|
asset_reference_id=ref.id,
|
||||||
@@ -186,46 +182,3 @@ class TestMetadataFilterEmptyDict:

         refs, _, total = list_references_page(session, metadata_filter={})
         assert total == 2
-
-
-class TestSystemMetadataProjection:
-    """Tests for system_metadata merging into the filter projection."""
-
-    def test_system_metadata_keys_are_filterable(self, session: Session):
-        """system_metadata keys should appear in the merged projection."""
-        asset = _make_asset(session, "hash1")
-        _make_reference(
-            session, asset, "with_sys",
-            system_metadata={"source": "scanner"},
-        )
-        _make_reference(session, asset, "without_sys")
-        session.commit()
-
-        refs, _, total = list_references_page(
-            session, metadata_filter={"source": "scanner"}
-        )
-        assert total == 1
-        assert refs[0].name == "with_sys"
-
-    def test_user_metadata_overrides_system_metadata(self, session: Session):
-        """user_metadata should win when both have the same key."""
-        asset = _make_asset(session, "hash1")
-        _make_reference(
-            session, asset, "overridden",
-            metadata={"origin": "user_upload"},
-            system_metadata={"origin": "auto_scan"},
-        )
-        session.commit()
-
-        # Should match the user value, not the system value
-        refs, _, total = list_references_page(
-            session, metadata_filter={"origin": "user_upload"}
-        )
-        assert total == 1
-        assert refs[0].name == "overridden"
-
-        # Should NOT match the system value (it was overridden)
-        refs, _, total = list_references_page(
-            session, metadata_filter={"origin": "auto_scan"}
-        )
-        assert total == 0
@@ -11,7 +11,6 @@ from app.assets.services import (
     delete_asset_reference,
     set_asset_preview,
 )
-from app.assets.services.asset_management import resolve_hash_to_path


 def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
@@ -220,33 +219,31 @@ class TestSetAssetPreview:
         asset = _make_asset(session, hash_val="blake3:main")
         preview_asset = _make_asset(session, hash_val="blake3:preview")
         ref = _make_reference(session, asset)
-        preview_ref = _make_reference(session, preview_asset, name="preview.png")
         ref_id = ref.id
-        preview_ref_id = preview_ref.id
+        preview_id = preview_asset.id
         session.commit()

         set_asset_preview(
             reference_id=ref_id,
-            preview_reference_id=preview_ref_id,
+            preview_asset_id=preview_id,
         )

         # Verify by re-fetching from DB
         session.expire_all()
         updated_ref = session.get(AssetReference, ref_id)
-        assert updated_ref.preview_id == preview_ref_id
+        assert updated_ref.preview_id == preview_id

     def test_clears_preview(self, mock_create_session, session: Session):
         asset = _make_asset(session)
         preview_asset = _make_asset(session, hash_val="blake3:preview")
         ref = _make_reference(session, asset)
-        preview_ref = _make_reference(session, preview_asset, name="preview.png")
-        ref.preview_id = preview_ref.id
+        ref.preview_id = preview_asset.id
         ref_id = ref.id
         session.commit()

         set_asset_preview(
             reference_id=ref_id,
-            preview_reference_id=None,
+            preview_asset_id=None,
         )

         # Verify by re-fetching from DB
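These tests verify persisted state rather than in-memory attributes: after `session.commit()`, `session.expire_all()` discards cached attribute state so the subsequent `session.get(...)` re-reads from the database. A self-contained sketch of the pattern, assuming a throwaway model and in-memory SQLite rather than the app's real schema:

```python
from sqlalchemy import create_engine, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class Ref(Base):  # hypothetical stand-in for AssetReference
    __tablename__ = "refs"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    preview_id: Mapped[str | None] = mapped_column(String, nullable=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    ref = Ref()
    session.add(ref)
    session.commit()

    ref.preview_id = "asset-123"
    session.commit()

    session.expire_all()              # drop cached attribute state
    fresh = session.get(Ref, ref.id)  # re-SELECTs the row from the database
    assert fresh.preview_id == "asset-123"
```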
@@ -266,45 +263,6 @@ class TestSetAssetPreview:
         with pytest.raises(PermissionError, match="not owner"):
             set_asset_preview(
                 reference_id=ref.id,
-                preview_reference_id=None,
+                preview_asset_id=None,
                 owner_id="user2",
             )
-
-
-class TestResolveHashToPath:
-    def test_returns_none_for_unknown_hash(self, mock_create_session):
-        result = resolve_hash_to_path("blake3:" + "a" * 64)
-        assert result is None
-
-    @pytest.mark.parametrize(
-        "ref_owner, query_owner, expect_found",
-        [
-            ("user1", "user1", True),
-            ("user1", "user2", False),
-            ("", "anyone", True),
-            ("", "", True),
-        ],
-        ids=[
-            "owner_sees_own_ref",
-            "other_owner_blocked",
-            "ownerless_visible_to_anyone",
-            "ownerless_visible_to_empty",
-        ],
-    )
-    def test_owner_visibility(
-        self, ref_owner, query_owner, expect_found,
-        mock_create_session, session: Session, temp_dir,
-    ):
-        f = temp_dir / "file.bin"
-        f.write_bytes(b"data")
-        asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
-        ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
-        ref.file_path = str(f)
-        session.commit()
-
-        result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
-        if expect_found:
-            assert result is not None
-            assert result.abs_path == str(f)
-        else:
-            assert result is None
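The parametrize table in the deleted test encodes a compact visibility rule: a reference with an empty `owner_id` is readable by anyone, while an owned reference is visible only to its owner. Restated as a pure predicate (a hypothetical helper mirroring the table, not a function from the app):

```python
def is_visible(ref_owner: str, query_owner: str) -> bool:
    # Ownerless ("" owner) references are world-readable;
    # owned references are visible only to their owner.
    return ref_owner == "" or ref_owner == query_owner

assert is_visible("user1", "user1") is True   # owner_sees_own_ref
assert is_visible("user1", "user2") is False  # other_owner_blocked
assert is_visible("", "anyone") is True       # ownerless_visible_to_anyone
assert is_visible("", "") is True             # ownerless_visible_to_empty
```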
@@ -113,19 +113,11 @@ class TestIngestFileFromPath:
         file_path = temp_dir / "with_preview.bin"
         file_path.write_bytes(b"data")

-        # Create a preview asset and reference
+        # Create a preview asset first
         preview_asset = Asset(hash="blake3:preview", size_bytes=100)
         session.add(preview_asset)
-        session.flush()
-        from app.assets.helpers import get_utc_now
-        now = get_utc_now()
-        preview_ref = AssetReference(
-            asset_id=preview_asset.id, name="preview.png", owner_id="",
-            created_at=now, updated_at=now, last_access_time=now,
-        )
-        session.add(preview_ref)
         session.commit()
-        preview_id = preview_ref.id
+        preview_id = preview_asset.id

         result = _ingest_file_from_path(
             abs_path=str(file_path),
@@ -1,123 +0,0 @@
-"""Tests for list_tag_histogram service function."""
-from sqlalchemy.orm import Session
-
-from app.assets.database.models import Asset, AssetReference
-from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
-from app.assets.helpers import get_utc_now
-from app.assets.services.tagging import list_tag_histogram
-
-
-def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
-    asset = Asset(hash=hash_val, size_bytes=1024)
-    session.add(asset)
-    session.flush()
-    return asset
-
-
-def _make_reference(
-    session: Session,
-    asset: Asset,
-    name: str = "test",
-    owner_id: str = "",
-) -> AssetReference:
-    now = get_utc_now()
-    ref = AssetReference(
-        owner_id=owner_id,
-        name=name,
-        asset_id=asset.id,
-        created_at=now,
-        updated_at=now,
-        last_access_time=now,
-    )
-    session.add(ref)
-    session.flush()
-    return ref
-
-
-class TestListTagHistogram:
-    def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
-        ensure_tags_exist(session, ["alpha", "beta"])
-        a1 = _make_asset(session, "blake3:aaa")
-        r1 = _make_reference(session, a1, name="r1")
-        add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
-
-        a2 = _make_asset(session, "blake3:bbb")
-        r2 = _make_reference(session, a2, name="r2")
-        add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
-        session.commit()
-
-        result = list_tag_histogram()
-
-        assert result["alpha"] == 2
-        assert result["beta"] == 1
-
-    def test_empty_when_no_assets(self, mock_create_session, session: Session):
-        ensure_tags_exist(session, ["unused"])
-        session.commit()
-
-        result = list_tag_histogram()
-
-        assert result == {}
-
-    def test_include_tags_filter(self, mock_create_session, session: Session):
-        ensure_tags_exist(session, ["models", "loras", "input"])
-        a1 = _make_asset(session, "blake3:aaa")
-        r1 = _make_reference(session, a1, name="r1")
-        add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
-
-        a2 = _make_asset(session, "blake3:bbb")
-        r2 = _make_reference(session, a2, name="r2")
-        add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
-        session.commit()
-
-        result = list_tag_histogram(include_tags=["models"])
-
-        # Only r1 has "models", so only its tags appear
-        assert "models" in result
-        assert "loras" in result
-        assert "input" not in result
-
-    def test_exclude_tags_filter(self, mock_create_session, session: Session):
-        ensure_tags_exist(session, ["models", "loras", "input"])
-        a1 = _make_asset(session, "blake3:aaa")
-        r1 = _make_reference(session, a1, name="r1")
-        add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
-
-        a2 = _make_asset(session, "blake3:bbb")
-        r2 = _make_reference(session, a2, name="r2")
-        add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
-        session.commit()
-
-        result = list_tag_histogram(exclude_tags=["models"])
-
-        # r1 excluded, only r2's tags remain
-        assert "input" in result
-        assert "loras" not in result
-
-    def test_name_contains_filter(self, mock_create_session, session: Session):
-        ensure_tags_exist(session, ["alpha", "beta"])
-        a1 = _make_asset(session, "blake3:aaa")
-        r1 = _make_reference(session, a1, name="my_model.safetensors")
-        add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
-
-        a2 = _make_asset(session, "blake3:bbb")
-        r2 = _make_reference(session, a2, name="picture.png")
-        add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
-        session.commit()
-
-        result = list_tag_histogram(name_contains="model")
-
-        assert "alpha" in result
-        assert "beta" not in result
-
-    def test_limit_caps_results(self, mock_create_session, session: Session):
-        tags = [f"tag{i}" for i in range(10)]
-        ensure_tags_exist(session, tags)
-        a = _make_asset(session, "blake3:aaa")
-        r = _make_reference(session, a, name="r1")
-        add_tags_to_reference(session, reference_id=r.id, tags=tags)
-        session.commit()
-
-        result = list_tag_histogram(limit=3)
-
-        assert len(result) == 3
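The deleted file exercised `list_tag_histogram`, which is at heart a tag-frequency count over references with optional include/exclude/name filters and a result cap. A standalone sketch of those semantics over plain data (assumed shapes; the include filter is read here as "reference must carry every include tag", which the single-tag test above cannot distinguish from an any-match):

```python
from collections import Counter

def tag_histogram(
    refs: dict[str, list[str]],            # reference name -> its tags
    include_tags: list[str] | None = None,
    exclude_tags: list[str] | None = None,
    name_contains: str | None = None,
    limit: int | None = None,
) -> dict[str, int]:
    counts: Counter[str] = Counter()
    for name, tags in refs.items():
        if name_contains and name_contains not in name:
            continue                        # name filter drops the whole reference
        if include_tags and not set(include_tags) <= set(tags):
            continue                        # keep refs carrying all include tags
        if exclude_tags and set(exclude_tags) & set(tags):
            continue                        # drop refs carrying any excluded tag
        counts.update(tags)
    return dict(counts.most_common(limit))  # most_common(None) returns everything

refs = {"r1": ["alpha", "beta"], "r2": ["alpha"]}
assert tag_histogram(refs) == {"alpha": 2, "beta": 1}
assert tag_histogram(refs, name_contains="r2") == {"alpha": 1}
```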
@@ -243,15 +243,6 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
     assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")


-def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
-    files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
-    form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
-    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
-    body = r.json()
-    assert r.status_code == 400
-    assert body["error"]["code"] == "INVALID_BODY"
-
-
 @pytest.mark.parametrize("root", ["input", "output"])
 def test_duplicate_upload_same_display_name_does_not_clobber(
     root: str,
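The deleted test doubles as documentation for the upload request shape: the file travels as a multipart part, while `tags`, `name`, and `user_metadata` are JSON-encoded form fields. A minimal client sketch with `requests` (the base URL and the non-empty tag list are assumptions; the deleted test sent `[]` and expected a 400):

```python
import json
import requests

files = {"file": ("model.bin", b"A" * 64, "application/octet-stream")}
form = {
    "tags": json.dumps(["models"]),
    "name": "model.bin",
    "user_metadata": json.dumps({}),
}
r = requests.post("http://127.0.0.1:8188/api/assets", data=form, files=files, timeout=120)
print(r.status_code, r.json())
```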
@@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
     },
     # JavaScript/CSS scenarios
     {
-        "name": "js_no_store",
+        "name": "js_no_cache",
         "path": "/script.js",
         "status": 200,
-        "expected_cache": "no-store",
+        "expected_cache": "no-cache",
         "should_have_header": True,
     },
     {
-        "name": "css_no_store",
+        "name": "css_no_cache",
         "path": "/styles.css",
         "status": 200,
-        "expected_cache": "no-store",
+        "expected_cache": "no-cache",
         "should_have_header": True,
     },
     {
-        "name": "index_json_no_store",
+        "name": "index_json_no_cache",
         "path": "/api/index.json",
         "status": 200,
-        "expected_cache": "no-store",
+        "expected_cache": "no-cache",
         "should_have_header": True,
     },
     {
-        "name": "localized_index_json_no_store",
+        "name": "localized_index_json_no_cache",
         "path": "/templates/index.zh.json",
         "status": 200,
-        "expected_cache": "no-store",
+        "expected_cache": "no-cache",
         "should_have_header": True,
     },
     # Non-matching files
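The renames above track a header change from `Cache-Control: no-store` (never store the response) to `no-cache` (store it, but revalidate before reuse), which keeps conditional requests and cheap 304s available for JS/CSS/JSON. A hedged sketch of where such a header gets attached in an aiohttp app; this is illustrative, not ComfyUI's actual middleware:

```python
from aiohttp import web

@web.middleware
async def cache_control_middleware(request: web.Request, handler) -> web.StreamResponse:
    response = await handler(request)
    # "no-cache" permits storing the response but forces revalidation
    # (e.g. If-None-Match / ETag) before the cached copy is reused.
    if request.path.endswith((".js", ".css", ".json")):
        response.headers["Cache-Control"] = "no-cache"
    return response

app = web.Application(middlewares=[cache_control_middleware])
```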