Register uploaded images in asset database when --enable-assets is set

Add register_file_in_place() service function to ingest module for
registering already-saved files without moving them. Call it from the
/upload/image endpoint to return asset metadata in the response.

Amp-Thread-ID: https://ampcode.com/threads/T-019ce023-3384-7560-bacf-de40b0de0dd2
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr
2026-03-11 21:35:18 -07:00
parent 31a1c03e3d
commit 0d5d4a3b8a
2 changed files with 79 additions and 2 deletions

View File

@@ -346,6 +346,61 @@ def upload_from_temp_path(
)
def register_file_in_place(
abs_path: str,
name: str,
tags: list[str],
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it."""
try:
digest, _ = hashing.compute_blake3_hash(abs_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash file: {e}")
asset_hash = "blake3:" + digest
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
content_type = mime_type or (
mimetypes.guess_type(abs_path, strict=False)[0]
or "application/octet-stream"
)
ingest_result = _ingest_file_from_path(
abs_path=abs_path,
asset_hash=asset_hash,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name, fallback=digest),
owner_id=owner_id,
tags=tags,
tag_origin="upload",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,

View File

@@ -35,6 +35,7 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@@ -163,7 +164,11 @@ def create_origin_only_middleware():
if host_domain_parsed.port is None:
origin_domain = parsed.hostname
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
# When both host and origin are loopback, allow port differences
# (e.g. frontend dev server on :5173 proxying to backend on :8188)
if loopback and is_loopback(parsed.hostname):
pass
elif loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
@@ -419,7 +424,24 @@ class PromptServer():
with open(filepath, "wb") as f:
f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,
"asset_hash": result.asset.hash if result.asset else None,
"size": result.asset.size_bytes if result.asset else None,
"mime_type": result.asset.mime_type if result.asset else None,
"tags": result.tags,
}
except Exception:
logging.warning("Failed to register uploaded image as asset", exc_info=True)
return web.json_response(resp)
else:
return web.Response(status=400)