added create_asset_from_hash endpoint

This commit is contained in:
bigcat88
2025-08-24 14:15:21 +03:00
parent 0755e5320a
commit f2ea0bc22c
5 changed files with 221 additions and 25 deletions

View File

@@ -20,10 +20,18 @@ from .database.services import (
fetch_asset_info_and_asset,
touch_asset_info_by_id,
delete_asset_info_by_id,
asset_exists_by_hash,
get_asset_by_hash,
create_asset_info_for_existing_asset,
)
from .api import schemas_out
async def asset_exists(*, asset_hash: str) -> bool:
async with await create_session() as session:
return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None:
if not args.disable_model_processing:
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
@@ -69,14 +77,14 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No
async def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str | None = "created_at",
order: str | None = "desc",
sort: str = "created_at",
order: str = "desc",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@@ -157,9 +165,9 @@ async def resolve_asset_content_for_download(
async def update_asset(
*,
asset_info_id: int,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
name: Optional[str] = None,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
) -> schemas_out.AssetUpdated:
async with await create_session() as session:
info = await update_asset_info_full(
@@ -192,9 +200,49 @@ async def delete_asset_reference(*, asset_info_id: int) -> bool:
return r
async def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower()
async with await create_session() as session:
asset = await get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = await create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
added_by=None,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_hash=info.preview_hash,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
async def list_tags(
*,
prefix: str | None = None,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
@@ -217,26 +265,12 @@ async def list_tags(
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
async def add_tags_to_asset(
*,
asset_info_id: int,
tags: list[str],
origin: str = "manual",
added_by: str | None = None,
added_by: Optional[str] = None,
) -> schemas_out.TagsAdd:
async with await create_session() as session:
data = await add_tags_to_asset_info(
@@ -264,3 +298,24 @@ async def remove_tags_from_asset(
)
await session.commit()
return schemas_out.TagsRemove(**data)
def _safe_sort_field(requested: Optional[str]) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: Optional[str] , fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback