diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py
index f31f3a006..92aee6402 100644
--- a/app/assets/services/hashing.py
+++ b/app/assets/services/hashing.py
@@ -1,7 +1,8 @@
 import io
 import os
+from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import IO, Any, Callable
+from typing import IO, Any, Callable, Iterator
 
 from blake3 import blake3
 
@@ -20,6 +21,29 @@ class HashCheckpoint:
     file_size: int = 0
 
 
+@contextmanager
+def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
+    """Yield (file_object, is_path) with appropriate setup/teardown."""
+    if hasattr(fp, "read"):
+        seekable = getattr(fp, "seekable", lambda: False)()
+        orig_pos = None
+        if seekable:
+            try:
+                orig_pos = fp.tell()
+                if orig_pos != 0:
+                    fp.seek(0)
+            except io.UnsupportedOperation:
+                orig_pos = None
+        try:
+            yield fp, False
+        finally:
+            if orig_pos is not None:
+                fp.seek(orig_pos)
+    else:
+        with open(os.fspath(fp), "rb") as f:
+            yield f, True
+
+
 def compute_blake3_hash(
     fp: str | IO[bytes],
     chunk_size: int = DEFAULT_CHUNK,
@@ -42,12 +66,11 @@ def compute_blake3_hash(
         (None, checkpoint) on interruption (file paths only), or
         (None, None) on interruption of a file object
     """
-    if hasattr(fp, "read"):
-        digest = _hash_file_obj(fp, chunk_size, interrupt_check)
-        return digest, None
+    if chunk_size <= 0:
+        chunk_size = DEFAULT_CHUNK
 
-    with open(os.fspath(fp), "rb") as f:
-        if checkpoint is not None:
+    with _open_for_hashing(fp) as (f, is_path):
+        if checkpoint is not None and is_path:
             f.seek(checkpoint.bytes_processed)
             h = checkpoint.hasher
             bytes_processed = checkpoint.bytes_processed
@@ -55,15 +78,14 @@ def compute_blake3_hash(
             h = blake3()
             bytes_processed = 0
 
-        if chunk_size <= 0:
-            chunk_size = DEFAULT_CHUNK
-
         while True:
             if interrupt_check is not None and interrupt_check():
-                return None, HashCheckpoint(
-                    bytes_processed=bytes_processed,
-                    hasher=h,
-                )
+                if is_path:
+                    return None, HashCheckpoint(
+                        bytes_processed=bytes_processed,
+                        hasher=h,
+                    )
+                return None, None
             chunk = f.read(chunk_size)
             if not chunk:
                 break
@@ -71,38 +93,3 @@ def compute_blake3_hash(
             bytes_processed += len(chunk)
 
         return h.hexdigest(), None
-
-
-def _hash_file_obj(
-    file_obj: IO,
-    chunk_size: int = DEFAULT_CHUNK,
-    interrupt_check: InterruptCheck | None = None,
-) -> str | None:
-    if chunk_size <= 0:
-        chunk_size = DEFAULT_CHUNK
-
-    seekable = getattr(file_obj, "seekable", lambda: False)()
-    orig_pos = None
-
-    if seekable:
-        try:
-            orig_pos = file_obj.tell()
-            if orig_pos != 0:
-                file_obj.seek(0)
-        except io.UnsupportedOperation:
-            seekable = False
-            orig_pos = None
-
-    try:
-        h = blake3()
-        while True:
-            if interrupt_check is not None and interrupt_check():
-                return None
-            chunk = file_obj.read(chunk_size)
-            if not chunk:
-                break
-            h.update(chunk)
-        return h.hexdigest()
-    finally:
-        if seekable and orig_pos is not None:
-            file_obj.seek(orig_pos)