mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-03 04:09:57 +00:00
concurrency upload test + fixed 2 related bugs
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
@@ -13,13 +15,29 @@ async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_typ
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
existing_names = {t.name for t in existing}
|
||||
missing = [n for n in wanted if n not in existing_names]
|
||||
if missing:
|
||||
dialect = session.bind.dialect.name
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in missing]
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
return [by_name[n] for n in wanted if n in by_name]
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_id(
|
||||
|
||||
@@ -484,6 +484,7 @@ async def ingest_fs_asset(
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
|
||||
if preview_id:
|
||||
if not await session.get(Asset, preview_id):
|
||||
@@ -502,10 +503,34 @@ async def ingest_fs_asset(
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
async with session.begin_nested():
|
||||
asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now)
|
||||
session.add(asset)
|
||||
await session.flush()
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
res = await session.execute(ins)
|
||||
rowcount = int(res.rowcount or 0)
|
||||
asset = (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
if rowcount > 0:
|
||||
out["asset_created"] = True
|
||||
else:
|
||||
changed = False
|
||||
@@ -524,7 +549,6 @@ async def ingest_fs_asset(
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
|
||||
Reference in New Issue
Block a user