feat: remove Asset when there is no references left + bugfixes + more tests

This commit is contained in:
bigcat88
2025-09-08 22:37:39 +03:00
parent 0e9de2b7c9
commit dfb5703d40
7 changed files with 332 additions and 51 deletions

View File

@@ -52,7 +52,6 @@ def comfy_tmp_base_dir() -> Path:
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
_make_base_dirs(tmp)
yield tmp
# cleanup in a best-effort way; ComfyUI should not keep files open in this dir
with contextlib.suppress(Exception):
for p in sorted(tmp.rglob("*"), reverse=True):
if p.is_file() or p.is_symlink():
@@ -72,10 +71,9 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path):
- autoscan disabled
Returns (base_url, process, port)
"""
port = 8500 # _free_port()
port = _free_port()
db_url = "sqlite+aiosqlite:///:memory:"
# stdout/stderr capturing for debugging if something goes wrong
logs_dir = comfy_tmp_base_dir / "logs"
logs_dir.mkdir(exist_ok=True)
out_log = open(logs_dir / "stdout.log", "w", buffering=1)
@@ -138,28 +136,59 @@ def api_base(comfy_url_and_proc) -> str:
return base_url
@pytest.fixture
def make_asset_bytes() -> Callable[[str], bytes]:
def _make(name: str) -> bytes:
# Generate deterministic small content variations based on name
seed = sum(ord(c) for c in name) % 251
data = bytes((i * 31 + seed) % 256 for i in range(8192))
return data
return _make
async def _upload_asset(session: aiohttp.ClientSession, base: str, *, name: str, tags: list[str], meta: dict) -> dict:
make_asset_bytes = bytes((i % 251) for i in range(4096))
async def _post_multipart_asset(
session: aiohttp.ClientSession,
base: str,
*,
name: str,
tags: list[str],
meta: dict,
data: bytes,
extra_fields: dict | None = None,
) -> tuple[int, dict]:
form = aiohttp.FormData()
form.add_field("file", make_asset_bytes, filename=name, content_type="application/octet-stream")
form.add_field("file", data, filename=name, content_type="application/octet-stream")
form.add_field("tags", json.dumps(tags))
form.add_field("name", name)
form.add_field("user_metadata", json.dumps(meta))
if extra_fields:
for k, v in extra_fields.items():
form.add_field(k, v)
async with session.post(base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status in (200, 201), body
return r.status, body
@pytest.fixture
def make_asset_bytes() -> Callable[[str, int], bytes]:
def _make(name: str, size: int = 8192) -> bytes:
seed = sum(ord(c) for c in name) % 251
return bytes((i * 31 + seed) % 256 for i in range(size))
return _make
@pytest_asyncio.fixture
async def asset_factory(http: aiohttp.ClientSession, api_base: str):
"""
Returns create(name, tags, meta, data) -> response dict
Tracks created ids and deletes them after the test.
"""
created: list[str] = []
async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
assert status in (200, 201), body
created.append(body["id"])
return body
yield create
# cleanup by id
for aid in created:
with contextlib.suppress(Exception):
async with http.delete(f"{api_base}/api/assets/{aid}") as r:
await r.read()
@pytest_asyncio.fixture
async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
@@ -179,3 +208,25 @@ async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
body = await r.json()
assert r.status == 201, body
return body
@pytest_asyncio.fixture(autouse=True)
async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
yield
while True:
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
) as r:
body = await r.json()
if r.status != 200:
break
ids = [a["id"] for a in body.get("assets", [])]
if not ids:
break
for aid in ids:
with contextlib.suppress(Exception):
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
await dr.read()

View File

@@ -1,26 +0,0 @@
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_tags_listing_endpoint(http: aiohttp.ClientSession, api_base: str):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# Only used tags
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
# Should contain no tags
assert not [t["name"] for t in body2["tags"]]
# TODO-1: add some asset
# TODO-2: check that "used" tags are now non zero amount
# TODO-3: do a global teardown, so the state of ComfyUI is clear after each test, and all test can be run solo or one-by-one without any problems.

56
tests-assets/test_tags.py Normal file
View File

@@ -0,0 +1,56 @@
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# Only used tags before we add anything new from this test cycle
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
# We already seeded one asset via fixture, so used tags must be non-empty
used_names = [t["name"] for t in body2["tags"]]
assert "models" in used_names
assert "checkpoints" in used_names
# Prefix filter should refine the list
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [t["name"] for t in b3["tags"]]
assert "unit-tests" in names3
assert "models" not in names3 # filtered out by prefix
# Order by name ascending should be stable
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
b4 = await r4.json()
assert r4.status == 200
names4 = [t["name"] for t in b4["tags"]]
assert names4 == sorted(names4)
@pytest.mark.asyncio
async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# With include_zero=False there should be no tags returned for the database without Assets.
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
assert not [t["name"] for t in body2["tags"]]

View File

@@ -0,0 +1,133 @@
import json
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
body = await r.json()
assert r.status == 415
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
@pytest.mark.asyncio
async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData(default_to_multipart=True)
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
form.add_field("name", "x.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "MISSING_FILE"
@pytest.mark.asyncio
async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
form.add_field("name", "m.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
@pytest.mark.asyncio
async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
# '..' should be rejected by destination resolver
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
form.add_field("name", "evil.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
@pytest.mark.asyncio
async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
tags = ["models", "checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "dup"}
data = make_asset_bytes(name)
form1 = aiohttp.FormData()
form1.add_field("file", data, filename=name, content_type="application/octet-stream")
form1.add_field("tags", json.dumps(tags))
form1.add_field("name", name)
form1.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form1) as r1:
a1 = await r1.json()
assert r1.status == 201, a1
assert a1["created_new"] is True
# Second upload with the same data and name should return created_new == False and the same asset
form2 = aiohttp.FormData()
form2.add_field("file", data, filename=name, content_type="application/octet-stream")
form2.add_field("tags", json.dumps(tags))
form2.add_field("name", name)
form2.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form2) as r2:
a2 = await r2.json()
assert r2.status == 200, a2
assert a2["created_new"] is False
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["id"] == a1["id"] # old reference
# Third upload with the same data but new name should return created_new == False and the new AssetReference
form3 = aiohttp.FormData()
form3.add_field("file", data, filename=name, content_type="application/octet-stream")
form3.add_field("tags", json.dumps(tags))
form3.add_field("name", name + "_d")
form3.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form3) as r2:
a3 = await r2.json()
assert r2.status == 200, a3
assert a3["created_new"] is False
assert a3["asset_hash"] == a1["asset_hash"]
assert a3["id"] != a1["id"] # old reference
@pytest.mark.asyncio
async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
# Seed a small file first
name = "fastpath_seed.safetensors"
tags = ["models", "checkpoints", "unit-tests"]
meta = {}
form1 = aiohttp.FormData()
form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
form1.add_field("tags", json.dumps(tags))
form1.add_field("name", name)
form1.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form1) as r1:
b1 = await r1.json()
assert r1.status == 201, b1
h = b1["asset_hash"]
# Now POST /api/assets with only hash and no file
form2 = aiohttp.FormData()
form2.add_field("hash", h)
form2.add_field("tags", json.dumps(tags))
form2.add_field("name", "fastpath_copy.safetensors")
form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
async with http.post(api_base + "/api/assets", data=form2) as r2:
b2 = await r2.json()
assert r2.status == 200, b2 # fast path returns 200 with created_new == False
assert b2["created_new"] is False
assert b2["asset_hash"] == h
@pytest.mark.asyncio
async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
payload = {
"hash": "blake3:" + "0" * 64,
"name": "nonexistent.bin",
"tags": ["models", "checkpoints", "unit-tests"],
}
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
body = await r.json()
assert r.status == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"