From ac6513e142f881202c40eacc5e337982b777ccd0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:19:40 -0800 Subject: [PATCH 01/45] DynamicVram: Add casting / fix torch Buffer weights (#12749) * respect model dtype in non-comfy caster * utils: factor out parent and name functionality of set_attr * utils: implement set_attr_buffer for torch buffers * ModelPatcherDynamic: Implement torch Buffer loading If there is a buffer in dynamic - force load it. --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 22 ++++++++++++++++++---- comfy/utils.py | 19 +++++++++++++++---- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672..0f5966371 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -796,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406b..70f78a089 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -309,7 +310,7 @@ class ModelPatcher: return comfy.model_management.get_free_memory(device) def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +337,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher): for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") From eb011733b6e4d8a9f7b67a1787d817bfc8c0a5b4 Mon Sep 17 00:00:00 2001 From: Arthur R Longbottom Date: Tue, 3 Mar 2026 21:29:00 -0800 Subject: [PATCH 02/45] Fix VideoFromComponents.save_to crash when writing to BytesIO (#12683) * Fix VideoFromComponents.save_to crash when writing to BytesIO When `get_container_format()` or `get_stream_source()` is called on a tensor-based video (VideoFromComponents), it calls `save_to(BytesIO())`. Since BytesIO has no file extension, `av.open` can't infer the output format and throws `ValueError: Could not determine output format`. The sibling class `VideoFromFile` already handles this correctly via `get_open_write_kwargs()`, which detects BytesIO and sets the format explicitly. `VideoFromComponents` just never got the same treatment. This surfaces when any downstream node validates the container format of a tensor-based video, like TopazVideoEnhance or any node that calls `validate_container_format_is_mp4()`. Three-line fix in `comfy_api/latest/_input_impl/video_types.py`. * Add docstring to save_to to satisfy CI coverage check --- comfy_api/latest/_input_impl/video_types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a3d48c87f..58a37c9e8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput): codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, ): + """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput): extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value + elif isinstance(path, io.BytesIO): + # BytesIO has no file extension, so av.open can't infer the format. + # Default to mp4 since that's the only supported format anyway. + extra_kwargs["format"] = "mp4" with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: From d531e3fb2a885d675d5b6d3a496b4af5d9757af1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:47:44 -0800 Subject: [PATCH 03/45] model_patcher: Improve dynamic offload heuristic (#12759) Define a threshold below which a weight loading takes priority. This actually makes the offload consistent with non-dynamic, because what happens, is when non-dynamic fills ints to_load list, it will fill-up any left-over pieces that could fix large weights with small weights and load them, even though they were lower priority. This actually improves performance because the timy weights dont cost any VRAM and arent worth the control overhead of the DMA etc. --- comfy/model_patcher.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 70f78a089..168ce8430 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,7 +699,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False, default_device=None): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): default = False @@ -727,8 +727,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () - loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. + if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -1508,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) - loading.sort(reverse=True) + loading = self._load_list(for_dynamic=True, default_device=device_to) + loading.sort() for x in loading: - _, _, _, n, m, params = x + *_, module_mem, n, m, params = x def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): @@ -1627,9 +1632,9 @@ class ModelPatcherDynamic(ModelPatcher): return freed def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) + loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: - _, _, _, _, m, _ = x + *_, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) if ram_to_unload <= 0: return From 9b85cf955858b0aca6b7b30c30b404470ea0c964 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:49:13 -0800 Subject: [PATCH 04/45] Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754) * ops: dont unpin nothing This was calling into aimdo in the none case (offloaded weight). Whats worse, is aimdo syncs for unpinning an offloaded weight, as that is the corner case of a weight getting evicted by its own use which does require a sync. But this was heppening every offloaded weight causing slowdown. * mp: fix get_free_memory policy The ModelPatcherDynamic get_free_memory was deducting the model from to try and estimate the conceptual free memory with doing any offloading. This is kind of what the old memory_memory_required was estimating in ModelPatcher load logic, however in practical reality, between over-estimates and padding, the loader usually underloaded models enough such that sampling could send CFG +/- through together even when partially loaded. So don't regress from the status quo and instead go all in on the idea that offloading is less of an issue than debatching. Tell the sampler it can use everything. --- comfy/model_patcher.py | 14 +++++++------- comfy/ops.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 168ce8430..7e5ad7aa4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -307,7 +307,13 @@ class ModelPatcher: return self.model.lowvram_patch_counter def get_free_memory(self, device): - return comfy.model_management.get_free_memory(device) + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) @@ -1465,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher): vbar = self._vbar_get() return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory - def get_free_memory(self, device): - #NOTE: on high condition / batch counts, estimate should have already vacated - #all non-dynamic models so this is safe even if its not 100% true that this - #would all be avaiable for inference use. - return comfy.model_management.get_total_memory(device) - self.model_size() - #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. def pin_weight_to_device(self, key): diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..8275dd0a5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream): return os, weight_a, bias_a = offload_stream device=None - #FIXME: This is not good RTTI - if not isinstance(weight_a, torch.Tensor): + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): comfy_aimdo.model_vbar.vbar_unpin(s._v) device = weight_a if os is None: diff --git a/requirements.txt b/requirements.txt index 608b0cfa6..110568cd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.4 +comfy-aimdo>=0.2.5 requests #non essential dependencies: From 0a7446ade4bbeecfaf36e9a70eeabbeb0f6e59ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:59:56 +0200 Subject: [PATCH 05/45] Pass tokens when loading text gen model for text generation (#12755) Co-authored-by: Jedrzej Kosinski --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..8bcd09582 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,7 +428,7 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) From 8811db52db5d0aea49c1dbedd733a6b9304b83a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:12:37 -0800 Subject: [PATCH 06/45] comfy-aimdo 0.2.6 (#12764) Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest as an error after a number of workflow runs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 110568cd3..dae46d873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.5 +comfy-aimdo>=0.2.6 requests #non essential dependencies: From ac4a943ff364885166def5d418582db971554caf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:33:14 -0800 Subject: [PATCH 07/45] Initial load device should be cpu when using dynamic vram. (#12766) --- comfy/model_management.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f5966371..809600815 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -830,11 +830,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -842,7 +845,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -945,6 +948,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device From 43c64b6308f93c331f057e12799bad0a68be5117 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:06:20 -0800 Subject: [PATCH 08/45] Support the LTXAV 2.3 model. (#12773) --- comfy/ldm/lightricks/av_model.py | 185 ++++++- comfy/ldm/lightricks/embeddings_connector.py | 4 + comfy/ldm/lightricks/model.py | 186 ++++++- comfy/ldm/lightricks/vae/audio_vae.py | 7 +- .../vae/causal_audio_autoencoder.py | 67 +-- .../vae/causal_video_autoencoder.py | 48 +- comfy/ldm/lightricks/vocoders/vocoder.py | 523 +++++++++++++++++- comfy/model_base.py | 2 +- comfy/sd.py | 2 +- comfy/text_encoders/lt.py | 68 ++- 10 files changed, 959 insertions(+), 133 deletions(-) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 553fd5b38..08d686b7b 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -2,11 +2,16 @@ from typing import Tuple import torch import torch.nn as nn from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, CrossAttention, FeedForward, AdaLayerNormSingle, PixArtAlphaTextProjection, + NormSingleLinearTextProjection, LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module): v_context_dim=None, a_context_dim=None, attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=v_dim, @@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=vd_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=ad_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module): heads=v_heads, dim_head=vd_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module): a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations ) - self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) self.audio_scale_shift_table = nn.Parameter( - torch.empty(6, a_dim, device=device, dtype=dtype) + torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) ) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + self.scale_shift_table_a2v_ca_audio = nn.Parameter( torch.empty(5, a_dim, device=device, dtype=dtype) ) @@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module): vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vx.addcmul_(attn1_out, vgate_msa) del vgate_msa, attn1_out - vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) # audio if run_ax: @@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module): agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] ax.addcmul_(attn1_out, agate_msa) del agate_msa, attn1_out - ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) # video - audio cross attention. if run_a2v or run_v2a: @@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel): use_middle_indices_grid=False, timestep_scale_multiplier=1000.0, av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel): self.audio_attention_head_dim = audio_attention_head_dim self.audio_num_attention_heads = audio_num_attention_heads self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention # Calculate audio dimensions self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim @@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel): ) # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.audio_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations, ) + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + num_scale_shift_values = 4 self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, @@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel): ) # Audio caption projection - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.audio_inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) self.video_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) - def preprocess_text_embeds(self, context): - if context.shape[-1] == self.caption_channels * 2: - return context - out_vid = self.video_embeddings_connector(context)[0] - out_audio = self.audio_embeddings_connector(context)[0] + def preprocess_text_embeds(self, context, unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] return torch.concat((out_vid, out_audio), dim=-1) def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel): ad_head=self.audio_attention_head_dim, v_context_dim=self.cross_attention_dim, a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel): v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: @@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep_flat, + timestep.max().expand_as(a_timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep_flat, + a_timestep.max().expand_as(timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep_flat * av_ca_factor, + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep_flat * av_ca_factor, + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel): # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) else: a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] + a_prompt_timestep = None - return [v_timestep, a_timestep, cross_av_timestep_ss], [ + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ v_embedded_timestep, a_embedded_timestep, - ] + ], None def _prepare_context(self, context, batch_size, x, attention_mask=None): vx = x[0] ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + v_context, a_context = torch.split( - context, int(context.shape[-1] / 2), len(context.shape) - 1 + context, [v_context_dim, a_context_dim], len(context.shape) - 1 ) v_context, attention_mask = super()._prepare_context( v_context, batch_size, vx, attention_mask ) - if self.audio_caption_projection is not None: + if self.caption_proj_before_connector is False: a_context = self.audio_caption_projection(a_context) - a_context = a_context.view(batch_size, -1, ax.shape[-1]) + a_context = a_context.view(batch_size, -1, audio_dim) return [v_context, a_context], attention_mask @@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel): av_ca_v2a_gate_noise_timestep, ) = timestep[2] + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), ) return out @@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel): "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, }, {"original_block": block_wrap}, ) @@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 33adb9671..2811080be 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module): d_head, context_dim=None, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module): heads=n_heads, dim_head=d_head, context_dim=None, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module): positional_embedding_max_pos=[4096], causal_temporal_positioning=False, num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module): num_attention_heads, attention_head_dim, context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 60d760d29..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -343,6 +367,7 @@ class CrossAttention(nn.Module): dim_head=64, dropout=0.0, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -362,6 +387,12 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None + self.to_out = nn.Sequential( operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -383,16 +414,30 @@ class CrossAttention(nn.Module): out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) +ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): def __init__( - self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None ): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, @@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module): operations=operations, ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + ) + else: + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) @@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module): return x +def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype): + """Compute a single global prompt timestep for cross-attention ADaLN. + + Uses the max across tokens (matching JAX max_per_segment) and broadcasts + over text tokens. Returns None when *adaln_module* is None. + """ + if adaln_module is None: + return None + ts_input = ( + timestep_scaled.max(dim=1, keepdim=True).values.flatten() + if timestep_scaled.dim() > 1 + else timestep_scaled.flatten() + ) + prompt_ts, _ = adaln_module( + ts_input, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1]) + + +def apply_cross_attention_adaln( + x, context, attn, q_shift, q_scale, q_gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask=None, transformer_options={}, +): + """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV). + + Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so + that both regular tensors and CompressedTimestep are supported. + """ + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate + def get_fractional_positions(indices_grid, max_pos): n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' @@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC): vae_scale_factors: tuple = (8, 32, 32), use_middle_indices_grid=False, timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, dtype=None, device=None, operations=None, @@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC): self.causal_temporal_positioning = causal_temporal_positioning self.operations = operations self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear # Common dimensions self.inner_dim = num_attention_heads * attention_head_dim @@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC): self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None + + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) @abstractmethod def _init_model_components(self, device, dtype, **kwargs): @@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC): timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - return timestep, embedded_timestep + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep def _prepare_context(self, context, batch_size, x, attention_mask=None): """Prepare context for transformer blocks.""" - if self.caption_projection is not None: + if self.caption_proj_before_connector is False: context = self.caption_projection(context) - context = context.view(batch_size, -1, x.shape[-1]) + context = context.view(batch_size, -1, x.shape[-1]) return context, attention_mask def _precompute_freqs_cis( @@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC): merged_args.update(additional_args) # Prepare timestep and context - timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) # Prepare attention mask and positional embeddings @@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel): causal_temporal_positioning=False, vae_scale_factors=(8, 32, 32), use_middle_indices_grid=False, - timestep_scale_multiplier = 1000.0, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel): def _init_model_components(self, device, dtype, **kwargs): """Initialize LTXV-specific components.""" - # No additional components needed for LTXV beyond base class pass def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel): self.num_attention_heads, self.attention_head_dim, context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel): pe=pe, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) return x diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 55a074661..fa0a00748 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( CausalityAxis, CausalAudioAutoencoder, ) -from comfy.ldm.lightricks.vocoders.vocoder import Vocoder +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE LATENT_DOWNSAMPLE_FACTOR = 4 @@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module): vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) - self.vocoder = Vocoder(config=component_config.vocoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) self.autoencoder.load_state_dict(vae_sd, strict=False) self.vocoder.load_state_dict(vocoder_sd, strict=False) diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index f12b9bb53..b556b128f 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module): super().__init__() if config is None: - config = self._guess_config() + config = self.get_default_config() - # Extract encoder and decoder configs from the new format model_config = config.get("model", {}).get("params", {}) - variables_config = config.get("variables", {}) - self.sampling_rate = variables_config.get( - "sampling_rate", - model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + self.sampling_rate = model_config.get( + "sampling_rate", config.get("sampling_rate", 16000) ) encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) decoder_config = model_config.get("decoder", encoder_config) # Load mel spectrogram parameters self.mel_bins = encoder_config.get("mel_bins", 64) - self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) - self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) # Store causality configuration at VAE level (not just in encoder internals) - causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value) self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) self.is_causal = self.causality_axis == CausalityAxis.HEIGHT @@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module): self.per_channel_statistics = processor() - def _guess_config(self): - encoder_config = { - # Required parameters - based on ltx-video-av-1679000 model metadata - "ch": 128, - "out_ch": 8, - "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] - "num_res_blocks": 2, - "attn_resolutions": [], # Based on metadata: empty list, no attention - "dropout": 0.0, - "resamp_with_conv": True, - "in_channels": 2, # stereo - "resolution": 256, - "z_channels": 8, + def get_default_config(self): + ddconfig = { "double_z": True, - "attn_type": "vanilla", - "mid_block_add_attention": False, # Based on metadata: false + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, "norm_type": "pixel", - "causality_axis": "height", # Based on metadata - "mel_bins": 64, # Based on metadata: mel_bins = 64 - } - - decoder_config = { - # Inherits encoder config, can override specific params - **encoder_config, - "out_ch": 2, # Stereo audio output (2 channels) - "give_pre_end": False, - "tanh_out": False, + "causality_axis": "height", } config = { - "_class_name": "CausalAudioAutoencoder", - "sampling_rate": 16000, "model": { "params": { - "encoder": encoder_config, - "decoder": decoder_config, + "ddconfig": ddconfig, + "sampling_rate": 16000, } }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, } return config diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index cbfdf412d..5b57dfc5e 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + def mark_conv3d_ended(module): tid = threading.get_ident() for _, m in module.named_modules(): @@ -350,6 +353,10 @@ class Decoder(nn.Module): output_channel = output_channel * block_params.get("multiplier", 2) if block_name == "compress_all": output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_space": + output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_time": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -395,17 +402,21 @@ class Decoder(nn.Module): spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -455,6 +466,15 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: @@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) self.temporal_cache_state={} @@ -1012,9 +1041,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): + def decode(self, x): if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) - + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index b1f15f2c5..6c4028aa8 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import torch.nn as nn import comfy.ops import numpy as np +import math ops = comfy.ops.disable_weight_init @@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + b = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() @@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module): """ Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). """ def __init__(self, config=None): @@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module): config = self.get_default_config() resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) - upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) - upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_initial_channel = config.get("upsample_initial_channel", 1024) stereo = config.get("stereo", True) - resblock = config.get("resblock", "1") + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + # "output_sample_rate" is not present in recent checkpoint configs. + # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock_class(ch, k, d)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) out_channels = 2 if stereo else 1 - self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + def get_default_config(self): """Generate default configuration for the vocoder.""" config = { "resblock_kernel_sizes": [3, 7, 11], - "upsample_rates": [6, 5, 2, 2, 2], - "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_initial_channel": 1024, "stereo": True, "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, } return config @@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module): assert x.shape[1] == 2, "Input must have 2 channels for stereo" x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) x = self.conv_pre(x) + for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): @@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module): else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels - x = F.leaky_relu(x) + + x = self.act_post(x) x = self.conv_post(x) - x = torch.tanh(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. + + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. + + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/model_base.py b/comfy/model_base.py index 1e01e9edc..d9d5a9293 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1021,7 +1021,7 @@ class LTXAV(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) diff --git a/comfy/sd.py b/comfy/sd.py index 8bcd09582..888ef1e77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage elif clip_type == CLIPType.LTXV: - clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e86ea9f4e..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is +class DualLinearProjection(torch.nn.Module): + def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None): + super().__init__() + self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device) + self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device) + + def forward(self, x): + source_dim = x.shape[-1] + x = x.movedim(1, -1) + x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2) + + video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim)) + audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim)) + return torch.cat((video, audio), dim=-1) + class LTXAVTEModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}): super().__init__() self.dtypes = set() self.dtypes.add(dtype) self.compat_mode = False + self.text_projection_type = text_projection_type self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.dtypes.add(dtype_llama) operations = self.gemma3_12b.operations # TODO - self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + if self.text_projection_type == "single_linear": + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + elif self.text_projection_type == "dual_linear": + self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations) + def enable_compat_mode(self): # TODO: remove from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module): out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) - out = out.movedim(1, -1).to(self.execution_device) - out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) - out = out.reshape((out.shape[0], out.shape[1], -1)) - out = self.text_embedding_projection(out) - out = out.float() - if self.compat_mode: - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) + if self.text_projection_type == "single_linear": + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) - return out.to(out_device), pooled + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) @@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) if len(sdo) == 0: sdo = sd @@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module): num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 -def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options) return LTXAVTEModel_ + +def sd_detect(state_dict_list, prefix=""): + for sd in state_dict_list: + if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd: + return {"text_projection_type": "dual_linear"} + if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd: + return {"text_projection_type": "single_linear"} + return {} + + def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): class Gemma3_12BModel_(Gemma3_12BModel): def __init__(self, device="cpu", dtype=None, model_options={}): From f2ee7f2d367f98bb8a33bcb4a224bda441eb8a07 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:21:55 -0800 Subject: [PATCH 09/45] Fix cublas ops on dynamic vram. (#12776) --- comfy/ops.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 8275dd0a5..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -660,23 +660,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations From c5fe8ace68c432a262a5093bdd84b3ed70b9d283 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 15:37:35 +0800 Subject: [PATCH 10/45] chore: update workflow templates to v0.9.6 (#12778) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dae46d873..5f99407b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.5 +comfyui-workflow-templates==0.9.6 comfyui-embedded-docs==0.4.3 torch torchsde From 4941671b5a5c65fea48be922caa76b7f6a0a4595 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:39:51 -0800 Subject: [PATCH 11/45] Fix cuda getting initialized in cpu mode. (#12779) --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 809600815..ee28ea107 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1666,12 +1666,16 @@ def lora_compute_dtype(device): return dtype def synchronize(): + if cpu_mode(): + return if is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() From c8428541a6b6e4b1e0fbd685e9c846efcb60179e Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 16:58:25 +0800 Subject: [PATCH 12/45] chore: update workflow templates to v0.9.7 (#12780) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f99407b7..866818e08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.6 +comfyui-workflow-templates==0.9.7 comfyui-embedded-docs==0.4.3 torch torchsde From e04d0dbeb8266aa9262b5a4c3934ba4e4a371e37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 04:06:29 -0500 Subject: [PATCH 13/45] ComfyUI v0.16.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 6a35c6de3..0aea18d3a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.1" +__version__ = "0.16.0" diff --git a/pyproject.toml b/pyproject.toml index 1b2318273..f2133d99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.1" +version = "0.16.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From bd21363563ce8e312c9271a0c64a0145335df8a9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:29:39 +0200 Subject: [PATCH 14/45] feat(api-nodes-xAI): updated models, pricing, added features (#12756) --- comfy_api_nodes/apis/grok.py | 14 ++++-- comfy_api_nodes/nodes_grok.py | 92 +++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index 8e3c79ab9..c56c8aecc 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel): aspect_ratio: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + resolution: str = Field(...) class InputUrlObject(BaseModel): @@ -16,12 +17,13 @@ class InputUrlObject(BaseModel): class ImageEditRequest(BaseModel): model: str = Field(...) - image: InputUrlObject = Field(...) + images: list[InputUrlObject] = Field(...) prompt: str = Field(...) resolution: str = Field(...) n: int = Field(...) seed: int = Field(...) - response_for: str = Field("url") + response_format: str = Field("url") + aspect_ratio: str | None = Field(...) class VideoGenerationRequest(BaseModel): @@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel): revised_prompt: str | None = Field(None) +class UsageObject(BaseModel): + cost_in_usd_ticks: int | None = Field(None) + + class ImageGenerationResponse(BaseModel): data: list[ImageResponseObject] = Field(...) + usage: UsageObject | None = Field(None) class VideoGenerationResponse(BaseModel): @@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel): status: str | None = Field(None) video: VideoResponseObject | None = Field(None) model: str | None = Field(None) + usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index da15e97ea..0716d6239 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -27,6 +27,12 @@ from comfy_api_nodes.util import ( ) +def _extract_grok_price(response) -> float | None: + if response.usage and response.usage.cost_in_usd_ticks is not None: + return response.usage.cost_in_usd_ticks / 10_000_000_000 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode): category="api node/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), IO.String.Input( "prompt", multiline=True, @@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), ], outputs=[ IO.Image.Output(), @@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": $rate * widgets.number_of_images} + ) + """, ), ) @@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio: str, number_of_images: int, seed: int, + resolution: str = "1K", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) response = await sync_op( @@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode): aspect_ratio=aspect_ratio, n=number_of_images, seed=seed, + resolution=resolution.lower(), ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode): category="api node/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-image-beta"]), - IO.Image.Input("image"), + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.Image.Input("image", display_name="images"), IO.String.Input( "prompt", multiline=True, tooltip="The text prompt used to generate the image", ), - IO.Combo.Input("resolution", options=["1K"]), + IO.Combo.Input("resolution", options=["1K", "2K"]), IO.Int.Input( "number_of_images", default=1, @@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + optional=True, + tooltip="Only allowed when multiple images are connected to the image input.", + ), ], outputs=[ IO.Image.Output(), @@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), - expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + ) + """, ), ) @@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode): resolution: str, number_of_images: int, seed: int, + aspect_ratio: str = "auto", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") + if model == "grok-imagine-image-pro": + if get_number_of_images(image) > 1: + raise ValueError("The pro model supports only 1 input image.") + elif get_number_of_images(image) > 3: + raise ValueError("A maximum of 3 input images is supported.") + if aspect_ratio != "auto" and get_number_of_images(image) == 1: + raise ValueError( + "Custom aspect ratio is only allowed when multiple images are connected to the image input." + ) response = await sync_op( cls, ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), data=ImageEditRequest( model=model, - image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], prompt=prompt, resolution=resolution.lower(), n=number_of_images, seed=seed, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, ), response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) @@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode): category="api node/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), expr=""" ( - $base := 0.181 * widgets.duration; + $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $base := $rate * widgets.duration; {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} ) """, @@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) @@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode): category="api node/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), IO.String.Input( "prompt", multiline=True, @@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", ), ) @@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) From 9cdfd7403bc46f75d12be16ba6041b8bcdd3f7fd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:12:38 +0200 Subject: [PATCH 15/45] feat(api-nodes): enable Kling 3.0 Motion Control (#12785) --- comfy_api_nodes/apis/kling.py | 1 + comfy_api_nodes/nodes_kling.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index a5bd5f1d3..fe0f97cb3 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel): keep_original_sound: str = Field(...) character_orientation: str = Field(...) mode: str = Field(..., description="'pro' or 'std'") + model_name: str = Field(...) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 74fa078ff..8963c335d 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode): "but the character orientation matches the reference image (camera/other details via prompt).", ), IO.Combo.Input("mode", options=["pro", "std"]), + IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True), ], outputs=[ IO.Video.Output(), @@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound: bool, character_orientation: str, mode: str, + model: str = "kling-v2-6", ) -> IO.NodeOutput: validate_string(prompt, max_length=2500) validate_image_dimensions(reference_image, min_width=340, min_height=340) @@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode): keep_original_sound="yes" if keep_original_sound else "no", character_orientation=character_orientation, mode=mode, + model_name=model, ), ) if response.code: From da29b797ce00b491c269e864cc3b8fceb279e530 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 5 Mar 2026 23:23:23 +0800 Subject: [PATCH 16/45] Update workflow templates to v0.9.8 (#12788) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 866818e08..3fd44e0cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.7 +comfyui-workflow-templates==0.9.8 comfyui-embedded-docs==0.4.3 torch torchsde From 6ef82a89b83a49247081dc57b154172573c9e313 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 10:38:33 -0500 Subject: [PATCH 17/45] ComfyUI v0.16.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0aea18d3a..e58e0fb63 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.0" +__version__ = "0.16.1" diff --git a/pyproject.toml b/pyproject.toml index f2133d99c..199a90364 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.0" +version = "0.16.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 6481569ad4c3606bc50e9de39ce810651690ae79 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:04:24 -0800 Subject: [PATCH 18/45] comfy-aimdo 0.2.7 (#12791) Comfy-aimdo 0.2.7 fixes a crash when a spurious cudaAsyncFree comes in and would cause an infinite stack overflow (via detours hooks). A lock is also introduced on the link list holding the free sections to avoid any possibility of threaded miscellaneous cuda allocations being the root cause. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3fd44e0cf..f7098b730 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.6 +comfy-aimdo>=0.2.7 requests #non essential dependencies: From 42e0e023eee6a19c1adb7bd3dc11c81ff6dcc9c8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:22:17 -0800 Subject: [PATCH 19/45] ops: Handle CPU weight in VBAR caster (#12792) This shouldn't happen but custom nodes gets there. Handle it as best we can. --- comfy/ops.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy/ops.py b/comfy/ops.py index 3e19cd1b6..06aa41d4f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if device is not None and device.type == "cpu": + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = None + if s.bias is not None: + bias = s.bias.to(dtype=bias_dtype, copy=True) + return weight, bias, (None, None, None) + offload_stream = None xfer_dest = None From 5073da57ad20a2abb921f79458e49a7f7d608740 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 6 Mar 2026 02:22:38 +0800 Subject: [PATCH 20/45] chore: update workflow templates to v0.9.10 (#12793) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f7098b730..9a674fac5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.8 +comfyui-workflow-templates==0.9.10 comfyui-embedded-docs==0.4.3 torch torchsde From 1c3b651c0a1539a374e3d29a3ce695b5844ac5fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:35:56 -0800 Subject: [PATCH 21/45] Refactor. (#12794) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 06aa41d4f..87b36b5c5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,7 +86,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu #a clone conservatively as we are mmapped and some SFT files are packed misaligned #If you are a custom node author reading this, please move your layer to the GPU #or declare your ModelPatcher as CPU in the first place. - if device is not None and device.type == "cpu": + if comfy.model_management.is_device_cpu(device): weight = s.weight.to(dtype=dtype, copy=True) if isinstance(weight, QuantizedTensor): weight = weight.dequantize() From 50549aa252903b936b2ed00b5de418c8b47f0841 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 13:41:06 -0500 Subject: [PATCH 22/45] ComfyUI v0.16.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index e58e0fb63..bc49f2218 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.1" +__version__ = "0.16.2" diff --git a/pyproject.toml b/pyproject.toml index 199a90364..73bfd1007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.1" +version = "0.16.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 8befce5c7b84ff3451a6bd3bcbae1355ad322855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:37:25 +0200 Subject: [PATCH 23/45] Add manual cast to LTX2 vocoder conv_transpose1d (#12795) * Add manual cast to LTX2 vocoder * Update vocoder.py --- comfy/ldm/lightricks/vocoders/vocoder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index 6c4028aa8..a0e03cada 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F import torch.nn as nn import comfy.ops +import comfy.model_management import numpy as np import math @@ -125,7 +126,7 @@ class UpSample1d(nn.Module): _, C, _ = x.shape x = F.pad(x, (self.pad, self.pad), mode="replicate") x = self.ratio * F.conv_transpose1d( - x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C ) x = x[..., self.pad_left : -self.pad_right] return x From 17b43c2b87eba43f0f071471b855e0ed659a2627 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:31:28 -0800 Subject: [PATCH 24/45] LTX audio vae novram fixes. (#12796) --- comfy/ldm/lightricks/vocoders/vocoder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index a0e03cada..2481d8bdd 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -82,7 +82,7 @@ class LowPassFilter1d(nn.Module): _, C, _ = x.shape if self.padding: x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) - return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) class UpSample1d(nn.Module): @@ -191,7 +191,7 @@ class Snake(nn.Module): self.eps = 1e-9 def forward(self, x): - a = self.alpha.unsqueeze(0).unsqueeze(-1) + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) if self.alpha_logscale: a = torch.exp(a) return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) @@ -218,8 +218,8 @@ class SnakeBeta(nn.Module): self.eps = 1e-9 def forward(self, x): - a = self.alpha.unsqueeze(0).unsqueeze(-1) - b = self.beta.unsqueeze(0).unsqueeze(-1) + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) if self.alpha_logscale: a = torch.exp(a) b = torch.exp(b) @@ -597,7 +597,7 @@ class _STFTFn(nn.Module): y = y.unsqueeze(1) # (B, 1, T) left_pad = max(0, self.win_length - self.hop_length) # causal: left-only y = F.pad(y, (left_pad, 0)) - spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0) n_freqs = spec.shape[1] // 2 real, imag = spec[:, :n_freqs], spec[:, n_freqs:] magnitude = torch.sqrt(real ** 2 + imag ** 2) @@ -648,7 +648,7 @@ class MelSTFT(nn.Module): """ magnitude, phase = self.stft_fn(y) energy = torch.norm(magnitude, dim=1) - mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude) log_mel = torch.log(torch.clamp(mel, min=1e-5)) return log_mel, magnitude, phase, energy From 58017e8726bdddae89704b1e0123bedc29994424 Mon Sep 17 00:00:00 2001 From: Tavi Halperin Date: Thu, 5 Mar 2026 23:51:20 +0200 Subject: [PATCH 25/45] feat: add causal_fix parameter to add_keyframe_index and append_keyframe (#12797) Allows explicit control over the causal_fix flag passed to latent_to_pixel_coords. Defaults to frame_idx == 0 when not specified, fixing the previous heuristic. --- comfy_extras/nodes_lt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 32fe921ff..c05571143 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) - pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 + if causal_fix is None: + causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1 + pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix) pixel_coords[:, 0] += frame_idx # The following adjusts keyframe end positions for small grid IC-LoRA. @@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None): if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) if guide_mask is not None: target_h = max(noise_mask.shape[3], guide_mask.shape[3]) From 1c218282369a6cc80651d878fc51fa33d7bf34e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Mar 2026 17:25:49 -0500 Subject: [PATCH 26/45] ComfyUI v0.16.3 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index bc49f2218..5da21150b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.2" +__version__ = "0.16.3" diff --git a/pyproject.toml b/pyproject.toml index 73bfd1007..6a83c5c63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.2" +version = "0.16.3" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From e544c65db91df5a070be69a0a9b922201fe79335 Mon Sep 17 00:00:00 2001 From: Dante Date: Fri, 6 Mar 2026 11:51:28 +0900 Subject: [PATCH 27/45] feat: add Math Expression node with simpleeval evaluation (#12687) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add EagerEval dataclass for frontend-side node evaluation Add EagerEval to the V3 API schema, enabling nodes to declare frontend-evaluated JSONata expressions. The frontend uses this to display computation results as badges without a backend round-trip. Co-Authored-By: Claude Opus 4.6 * feat: add Math Expression node with JSONata evaluation Add ComfyMathExpression node that evaluates JSONata expressions against dynamically-grown numeric inputs using Autogrow + MatchType. Sends input context via ui output so the frontend can re-evaluate when the expression changes without a backend round-trip. Co-Authored-By: Claude Opus 4.6 * feat: register nodes_math.py in extras_files loader list Co-Authored-By: Claude Opus 4.6 * fix: address CodeRabbit review feedback - Harden EagerEval.validate with type checks and strip() for empty strings - Add _positional_alias for spreadsheet-style names beyond z (aa, ab...) - Validate JSONata result is numeric before returning - Add jsonata to requirements.txt Co-Authored-By: Claude Opus 4.6 * refactor: remove EagerEval, scope PR to math node only Remove EagerEval dataclass from _io.py and eager_eval usage from nodes_math.py. Eager execution will be designed as a general-purpose system in a separate effort. Co-Authored-By: Claude Opus 4.6 * fix: use TemplateNames, cap inputs at 26, improve error message Address Kosinkadink review feedback: - Switch from Autogrow.TemplatePrefix to Autogrow.TemplateNames so input slots are named a-z, matching expression variables directly - Cap max inputs at 26 (a-z) instead of 100 - Simplify execute() by removing dual-mapping hack - Include expression and result value in error message Co-Authored-By: Claude Opus 4.6 * test: add unit tests for Math Expression node Add tests for _positional_alias (a-z mapping) and execute() covering arithmetic operations, float inputs, $sum(values), and error cases. Co-Authored-By: Claude Opus 4.6 * refactor: replace jsonata with simpleeval for math evaluation jsonata PyPI package has critical issues: no Python 3.12/3.13 wheels, no ARM/Apple Silicon wheels, abandoned (last commit 2023), C extension. Replace with simpleeval (pure Python, 3.4M downloads/month, MIT, AST-based security). Add math module functions (sqrt, ceil, floor, log, sin, cos, tan) and variadic sum() supporting both sum(values) and sum(a, b, c). Pin version to >=1.0,<2.0. Co-Authored-By: Claude Opus 4.6 * test: update tests for simpleeval migration Update JSONata syntax to Python syntax ($sum -> sum, $string -> str), add tests for math functions (sqrt, ceil, floor, sin, log10) and variadic sum(a, b, c). Co-Authored-By: Claude Opus 4.6 * refactor: replace MatchType with MultiType inputs and dual FLOAT/INT outputs Allow mixing INT and FLOAT connections on the same node by switching from MatchType (which forces all inputs to the same type) to MultiType. Output both FLOAT and INT so users can pick the type they need. Co-Authored-By: Claude Opus 4.6 * test: update tests for mixed INT/FLOAT inputs and dual outputs Add assertions for both FLOAT (result[0]) and INT (result[1]) outputs. Add test_mixed_int_float_inputs and test_mixed_resolution_scale to verify the primary use case of multiplying resolutions by a float factor. Co-Authored-By: Claude Opus 4.6 * feat: make expression input multiline and validate empty expression - Add multiline=True to expression input for better UX with longer expressions - Add empty expression validation with clear "Expression cannot be empty." message Co-Authored-By: Claude Opus 4.6 * test: add tests for empty expression validation Co-Authored-By: Claude Opus 4.6 * fix: address review feedback — safe pow, isfinite guard, test coverage - Wrap pow() with _safe_pow to prevent DoS via huge exponents (pow() bypasses simpleeval's safe_power guard on **) - Add math.isfinite() check to catch inf/nan before int() conversion - Add int/float converters to MATH_FUNCTIONS for explicit casting - Add "calculator" search alias - Replace _positional_alias helper with string.ascii_lowercase - Narrow test assertions and add error path + function coverage tests Co-Authored-By: Claude Opus 4.6 * Update requirements.txt --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Jedrzej Kosinski Co-authored-by: Christian Byrne --- comfy_extras/nodes_math.py | 119 +++++++++++ nodes.py | 1 + requirements.txt | 1 + .../comfy_extras_test/nodes_math_test.py | 197 ++++++++++++++++++ 4 files changed, 318 insertions(+) create mode 100644 comfy_extras/nodes_math.py create mode 100644 tests-unit/comfy_extras_test/nodes_math_test.py diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py new file mode 100644 index 000000000..6417bacf1 --- /dev/null +++ b/comfy_extras/nodes_math.py @@ -0,0 +1,119 @@ +"""Math expression node using simpleeval for safe evaluation. + +Provides a ComfyMathExpression node that evaluates math expressions +against dynamically-grown numeric inputs. +""" + +from __future__ import annotations + +import math +import string + +from simpleeval import simple_eval +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +MAX_EXPONENT = 4000 + + +def _variadic_sum(*args): + """Support both sum(values) and sum(a, b, c).""" + if len(args) == 1 and hasattr(args[0], "__iter__"): + return sum(args[0]) + return sum(args) + + +def _safe_pow(base, exp): + """Wrap pow() with an exponent cap to prevent DoS via huge exponents. + + The ** operator is already guarded by simpleeval's safe_power, but + pow() as a callable bypasses that guard. + """ + if abs(exp) > MAX_EXPONENT: + raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})") + return pow(base, exp) + + +MATH_FUNCTIONS = { + "sum": _variadic_sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "pow": _safe_pow, + "sqrt": math.sqrt, + "ceil": math.ceil, + "floor": math.floor, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "int": int, + "float": float, +} + + +class MathExpressionNode(io.ComfyNode): + """Evaluates a math expression against dynamically-grown inputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + autogrow = io.Autogrow.TemplateNames( + input=io.MultiType.Input("value", [io.Float, io.Int]), + names=list(string.ascii_lowercase), + min=1, + ) + return io.Schema( + node_id="ComfyMathExpression", + display_name="Math Expression", + category="math", + search_aliases=[ + "expression", "formula", "calculate", "calculator", + "eval", "math", + ], + inputs=[ + io.String.Input("expression", default="a + b", multiline=True), + io.Autogrow.Input("values", template=autogrow), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute( + cls, expression: str, values: io.Autogrow.Type + ) -> io.NodeOutput: + if not expression.strip(): + raise ValueError("Expression cannot be empty.") + + context: dict = dict(values) + context["values"] = list(values.values()) + + result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS) + # bool check must come first because bool is a subclass of int in Python + if isinstance(result, bool) or not isinstance(result, (int, float)): + raise ValueError( + f"Math Expression '{expression}' must evaluate to a numeric result, " + f"got {type(result).__name__}: {result!r}" + ) + if not math.isfinite(result): + raise ValueError( + f"Math Expression '{expression}' produced a non-finite result: {result}" + ) + return io.NodeOutput(float(result), int(result)) + + +class MathExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [MathExpressionNode] + + +async def comfy_entrypoint() -> MathExtension: + return MathExtension() diff --git a/nodes.py b/nodes.py index 5be9b16f9..0ef23b640 100644 --- a/nodes.py +++ b/nodes.py @@ -2449,6 +2449,7 @@ async def init_builtin_extra_nodes(): "nodes_replacements.py", "nodes_nag.py", "nodes_sdpose.py", + "nodes_math.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 9a674fac5..7bf12247c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,7 @@ av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests +simpleeval>=1.0 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py new file mode 100644 index 000000000..fa4cdcac3 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -0,0 +1,197 @@ +import math + +import pytest +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras.nodes_math import MathExpressionNode + + +class TestMathExpressionExecute: + @staticmethod + def _exec(expression: str, **kwargs) -> object: + values = OrderedDict(kwargs) + return MathExpressionNode.execute(expression, values) + + def test_addition(self): + result = self._exec("a + b", a=3, b=4) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_subtraction(self): + result = self._exec("a - b", a=10, b=3) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_multiplication(self): + result = self._exec("a * b", a=3, b=5) + assert result[0] == 15.0 + assert result[1] == 15 + + def test_division(self): + result = self._exec("a / b", a=10, b=4) + assert result[0] == 2.5 + assert result[1] == 2 + + def test_single_input(self): + result = self._exec("a * 2", a=5) + assert result[0] == 10.0 + assert result[1] == 10 + + def test_three_inputs(self): + result = self._exec("a + b + c", a=1, b=2, c=3) + assert result[0] == 6.0 + assert result[1] == 6 + + def test_float_inputs(self): + result = self._exec("a + b", a=1.5, b=2.5) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_mixed_int_float_inputs(self): + result = self._exec("a * b", a=1024, b=1.5) + assert result[0] == 1536.0 + assert result[1] == 1536 + + def test_mixed_resolution_scale(self): + result = self._exec("a * b", a=512, b=0.75) + assert result[0] == 384.0 + assert result[1] == 384 + + def test_sum_values_array(self): + result = self._exec("sum(values)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_sum_variadic(self): + result = self._exec("sum(a, b, c)", a=1, b=2, c=3) + assert result[0] == 6.0 + + def test_min_values(self): + result = self._exec("min(values)", a=5, b=2, c=8) + assert result[0] == 2.0 + + def test_max_values(self): + result = self._exec("max(values)", a=5, b=2, c=8) + assert result[0] == 8.0 + + def test_abs_function(self): + result = self._exec("abs(a)", a=-7) + assert result[0] == 7.0 + assert result[1] == 7 + + def test_sqrt(self): + result = self._exec("sqrt(a)", a=16) + assert result[0] == 4.0 + assert result[1] == 4 + + def test_ceil(self): + result = self._exec("ceil(a)", a=2.3) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_floor(self): + result = self._exec("floor(a)", a=2.7) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_sin(self): + result = self._exec("sin(a)", a=0) + assert result[0] == 0.0 + + def test_log10(self): + result = self._exec("log10(a)", a=100) + assert result[0] == 2.0 + assert result[1] == 2 + + def test_float_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[0], float) + + def test_int_output_type(self): + result = self._exec("a + b", a=1, b=2) + assert isinstance(result[1], int) + + def test_non_numeric_result_raises(self): + with pytest.raises(ValueError, match="must evaluate to a numeric result"): + self._exec("'hello'", a=42) + + def test_undefined_function_raises(self): + with pytest.raises(Exception, match="not defined"): + self._exec("str(a)", a=42) + + def test_boolean_result_raises(self): + with pytest.raises(ValueError, match="got bool"): + self._exec("a > b", a=5, b=3) + + def test_empty_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec("", a=1) + + def test_whitespace_only_expression_raises(self): + with pytest.raises(ValueError, match="Expression cannot be empty"): + self._exec(" ", a=1) + + # --- Missing function coverage (round, pow, log, log2, cos, tan) --- + + def test_round(self): + result = self._exec("round(a)", a=2.7) + assert result[0] == 3.0 + assert result[1] == 3 + + def test_round_with_ndigits(self): + result = self._exec("round(a, 2)", a=3.14159) + assert result[0] == pytest.approx(3.14) + + def test_pow(self): + result = self._exec("pow(a, b)", a=2, b=10) + assert result[0] == 1024.0 + assert result[1] == 1024 + + def test_log(self): + result = self._exec("log(a)", a=math.e) + assert result[0] == pytest.approx(1.0) + + def test_log2(self): + result = self._exec("log2(a)", a=8) + assert result[0] == pytest.approx(3.0) + + def test_cos(self): + result = self._exec("cos(a)", a=0) + assert result[0] == 1.0 + + def test_tan(self): + result = self._exec("tan(a)", a=0) + assert result[0] == 0.0 + + # --- int/float converter functions --- + + def test_int_converter(self): + result = self._exec("int(a / b)", a=7, b=2) + assert result[1] == 3 + + def test_float_converter(self): + result = self._exec("float(a)", a=5) + assert result[0] == 5.0 + + # --- Error path tests --- + + def test_division_by_zero_raises(self): + with pytest.raises(ZeroDivisionError): + self._exec("a / b", a=1, b=0) + + def test_sqrt_negative_raises(self): + with pytest.raises(ValueError, match="math domain error"): + self._exec("sqrt(a)", a=-1) + + def test_overflow_inf_raises(self): + with pytest.raises(ValueError, match="non-finite result"): + self._exec("a * b", a=1e308, b=10) + + def test_pow_huge_exponent_raises(self): + with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): + self._exec("pow(a, b)", a=10, b=10000000) From 3b93d5d571cb3e018da65f822cd11b60202b11c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:04:48 +0200 Subject: [PATCH 28/45] feat(api-nodes): add TencentSmartTopology node (#12741) * feat(api-nodes): add TencentSmartTopology node * feat(api-nodes): enable TencentModelTo3DUV node * chore(Tencent endpoints): add "wait" to queued statuses --- comfy_api_nodes/apis/hunyuan3d.py | 16 ++++- comfy_api_nodes/nodes_hunyuan3d.py | 109 ++++++++++++++++++++++++++--- comfy_api_nodes/util/client.py | 2 +- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py index e84eba31e..dad9bc2fa 100644 --- a/comfy_api_nodes/apis/hunyuan3d.py +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel): JobId: str = Field(...) -class To3DUVFileInput(BaseModel): +class TaskFile3DInput(BaseModel): Type: str = Field(..., description="File type: GLB, OBJ, or FBX") Url: str = Field(...) class To3DUVTaskRequest(BaseModel): - File: To3DUVFileInput = Field(...) + File: TaskFile3DInput = Field(...) + + +class To3DPartTaskRequest(BaseModel): + File: TaskFile3DInput = Field(...) class TextureEditImageInfo(BaseModel): @@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel): class TextureEditTaskRequest(BaseModel): - File3D: To3DUVFileInput = Field(...) + File3D: TaskFile3DInput = Field(...) Image: TextureEditImageInfo | None = Field(None) Prompt: str | None = Field(None) EnablePBR: bool | None = Field(None) + + +class SmartTopologyRequest(BaseModel): + File3D: TaskFile3DInput = Field(...) + PolygonType: str | None = Field(...) + FaceLevel: str | None = Field(...) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index d1d9578ec..bd8bde997 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import ( Hunyuan3DViewImage, InputGenerateType, ResultFile3D, + SmartTopologyRequest, + TaskFile3DInput, TextureEditTaskRequest, + To3DPartTaskRequest, To3DProTaskCreateResponse, To3DProTaskQueryRequest, To3DProTaskRequest, To3DProTaskResultResponse, - To3DUVFileInput, To3DUVTaskRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_file_3d, - download_url_to_image_tensor, downscale_image_tensor_by_max_side, poll_op, sync_op, @@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): outputs=[ IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DFBX.Output(display_name="FBX"), - IO.Image.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), response_model=To3DProTaskCreateResponse, data=To3DUVTaskRequest( - File=To3DUVFileInput( + File=TaskFile3DInput( Type=file_format.upper(), Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), ) @@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.NodeOutput( await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), - await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url), ) @@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), response_model=To3DProTaskCreateResponse, data=TextureEditTaskRequest( - File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), Prompt=prompt, EnablePBR=True, ), @@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), response_model=To3DProTaskCreateResponse, - data=To3DUVTaskRequest( - File=To3DUVFileInput(Type=file_format.upper(), Url=model_url), + data=To3DPartTaskRequest( + File=TaskFile3DInput(Type=file_format.upper(), Url=model_url), ), is_rate_limited=_is_tencent_rate_limited, ) @@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode): ) +class TencentSmartTopologyNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentSmartTopologyNode", + display_name="Hunyuan3D: Smart Topology", + category="api node/3d/Tencent", + description="Perform smart retopology on a 3D model. " + "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny], + tooltip="Input 3D model (GLB or OBJ)", + ), + IO.Combo.Input( + "polygon_type", + options=["triangle", "quadrilateral"], + tooltip="Surface composition type.", + ), + IO.Combo.Input( + "face_level", + options=["medium", "high", "low"], + tooltip="Polygon reduction level.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + polygon_type: str, + face_level: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." + ) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"), + response_model=To3DProTaskCreateResponse, + data=SmartTopologyRequest( + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + PolygonType=polygon_type, + FaceLevel=face_level, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + ) + + class TencentHunyuan3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TencentTextToModelNode, TencentImageToModelNode, - # TencentModelTo3DUVNode, + TencentModelTo3DUVNode, # Tencent3DTextureEditNode, Tencent3DPartNode, + TencentSmartTopologyNode, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 94886af7b..79ffb77c1 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -83,7 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] async def sync_op( From 34e55f006156801a6b5988d046d9041cb681f12d Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 6 Mar 2026 19:54:27 +0200 Subject: [PATCH 29/45] feat(api-nodes): add Gemini 3.1 Flash Lite model to LLM node (#12803) --- comfy_api_nodes/nodes_gemini.py | 47 +++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index d83d2fc15..8225ea67e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( ) -class GeminiModel(str, Enum): - """ - Gemini Model Names allowed by comfy-api - """ - - gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" - gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" - gemini_2_5_pro = "gemini-2.5-pro" - gemini_2_5_flash = "gemini-2.5-flash" - gemini_3_0_pro = "gemini-3-pro-preview" - - class GeminiImageModel(str, Enum): """ Gemini Image Model Names allowed by comfy-api @@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 - elif response.modelVersion == "gemini-3-pro-preview": + elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"): input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3.1-flash-lite-preview": + input_tokens_price = 0.25 + output_text_tokens_price = 1.50 + output_image_tokens_price = 0.0 elif response.modelVersion == "gemini-3-pro-image-preview": input_tokens_price = 2 output_text_tokens_price = 12.0 @@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiModel, - default=GeminiModel.gemini_2_5_pro, + options=[ + "gemini-2.5-pro-preview-05-06", + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-1-pro", + "gemini-3-1-flash-lite", + ], + default="gemini-3-1-pro", tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( @@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode): "usd": [0.00125, 0.01], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } - : $contains($m, "gemini-3-pro-preview") ? { + : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? { "type": "list_usd", "usd": [0.002, 0.012], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } + : $contains($m, "gemini-3-1-flash-lite") ? { + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : {"type":"text", "text":"Token-based"} ) """, @@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode): files: list[GeminiPart] | None = None, system_prompt: str = "", ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + if model == "gemini-3-pro-preview": + model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google + elif model == "gemini-3-1-pro": + model = "gemini-3.1-pro-preview" + elif model == "gemini-3-1-flash-lite": + model = "gemini-3.1-flash-lite-preview" - # Create parts list with text prompt as the first part parts: list[GeminiPart] = [GeminiPart(text=prompt)] - - # Add other modal parts if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: From f466b066017b9ebe5df67decfcbd09f78c5c66fa Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:20:07 -0800 Subject: [PATCH 30/45] Fix fp16 audio encoder models (#12811) * mp: respect model_defined_dtypes in default caster This is needed for parametrizations when the dtype changes between sd and model. * audio_encoders: archive model dtypes Archive model dtypes to stop the state dict load override the dtypes defined by the core for compute etc. --- comfy/audio_encoders/audio_encoders.py | 1 + comfy/model_patcher.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 16998af94..0de7584b0 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -27,6 +27,7 @@ class AudioEncoderModel(): self.model.eval() self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 + comfy.model_management.archive_model_dtypes(self.model) def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7e5ad7aa4..745384271 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -715,8 +715,8 @@ class ModelPatcher: default = True # default random weights in non leaf modules break if default and default_device is not None: - for param in params.values(): - param.data = param.data.to(device=default_device) + for param_name, param in params.items(): + param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None)) if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem From d69d30819b91aa020d0bb888df2a5b917f83bb7e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:11:16 -0800 Subject: [PATCH 31/45] Don't run TE on cpu when dynamic vram enabled. (#12815) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ee28ea107..39b4aa483 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -939,7 +939,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From afc00f00553885eeb96ded329878fe732f6b9f7a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 6 Mar 2026 17:10:53 -0800 Subject: [PATCH 32/45] Fix requirements version. (#12817) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7bf12247c..26e2ecdec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,7 @@ av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests -simpleeval>=1.0 +simpleeval>=1.0.0 #non essential dependencies: kornia>=0.7.1 From 6ac8152fc80734b084d12865460e5e9a5d9a4e1b Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Sat, 7 Mar 2026 15:54:09 +0800 Subject: [PATCH 33/45] chore: update workflow templates to v0.9.11 (#12821) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 26e2ecdec..dc9a9ded0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.10 +comfyui-workflow-templates==0.9.11 comfyui-embedded-docs==0.4.3 torch torchsde From bcf1a1fab1e9efe0d4999ea14e9c0318409e0000 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 7 Mar 2026 09:38:08 -0800 Subject: [PATCH 34/45] mm: reset_cast_buffers: sync compute stream before free (#12822) Sync the compute stream before freeing the cast buffers. This can cause use after free issues when the cast stream frees the buffer while the compute stream is behind enough to still needs a casted weight. --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39b4aa483..07bc8ad67 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1148,6 +1148,7 @@ def reset_cast_buffers(): LARGEST_CASTED_WEIGHT = (None, 0) for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() + synchronize() STREAM_CAST_BUFFERS.clear() soft_empty_cache() From a7a6335be538f55faa2abf7404c9b8e970847d1f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 7 Mar 2026 16:52:39 -0500 Subject: [PATCH 35/45] ComfyUI v0.16.4 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5da21150b..2723d02e7 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.3" +__version__ = "0.16.4" diff --git a/pyproject.toml b/pyproject.toml index 6a83c5c63..753b219b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.3" +version = "0.16.4" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 29b24cb5177e9d5aa5b3d2e5869999efb4d538c7 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Sat, 7 Mar 2026 17:37:25 -0800 Subject: [PATCH 36/45] refactor(assets): modular architecture + async two-phase scanner & background seeder (#12621) --- .../0002_merge_to_asset_references.py | 267 +++++ app/assets/api/routes.py | 810 ++++++++----- app/assets/api/schemas_in.py | 79 +- app/assets/api/schemas_out.py | 6 +- app/assets/api/upload.py | 171 +++ app/assets/database/bulk_ops.py | 204 ---- app/assets/database/models.py | 210 ++-- app/assets/database/queries.py | 976 ---------------- app/assets/database/queries/__init__.py | 121 ++ app/assets/database/queries/asset.py | 140 +++ .../database/queries/asset_reference.py | 1033 +++++++++++++++++ app/assets/database/queries/common.py | 54 + app/assets/database/queries/tags.py | 356 ++++++ app/assets/database/tags.py | 62 - app/assets/hashing.py | 75 -- app/assets/helpers.py | 319 +---- app/assets/manager.py | 516 -------- app/assets/scanner.py | 768 ++++++++---- app/assets/seeder.py | 794 +++++++++++++ app/assets/services/__init__.py | 87 ++ app/assets/services/asset_management.py | 309 +++++ app/assets/services/bulk_ingest.py | 280 +++++ app/assets/services/file_utils.py | 70 ++ app/assets/services/hashing.py | 95 ++ app/assets/services/ingest.py | 375 ++++++ app/assets/services/metadata_extract.py | 327 ++++++ app/assets/services/path_utils.py | 167 +++ app/assets/services/schemas.py | 109 ++ app/assets/services/tagging.py | 75 ++ app/database/db.py | 81 +- comfy/cli_args.py | 2 +- comfy_api/feature_flags.py | 1 + main.py | 35 +- requirements.txt | 2 + server.py | 19 +- tests-unit/assets_test/conftest.py | 16 +- tests-unit/assets_test/helpers.py | 28 + tests-unit/assets_test/queries/conftest.py | 20 + tests-unit/assets_test/queries/test_asset.py | 144 +++ .../assets_test/queries/test_asset_info.py | 517 +++++++++ .../assets_test/queries/test_cache_state.py | 499 ++++++++ .../assets_test/queries/test_metadata.py | 184 +++ tests-unit/assets_test/queries/test_tags.py | 366 ++++++ tests-unit/assets_test/services/__init__.py | 1 + tests-unit/assets_test/services/conftest.py | 54 + .../services/test_asset_management.py | 268 +++++ .../assets_test/services/test_bulk_ingest.py | 137 +++ .../assets_test/services/test_enrich.py | 207 ++++ .../assets_test/services/test_ingest.py | 229 ++++ .../assets_test/services/test_tagging.py | 197 ++++ .../assets_test/test_assets_missing_sync.py | 2 +- tests-unit/assets_test/test_crud.py | 138 ++- tests-unit/assets_test/test_downloads.py | 4 +- tests-unit/assets_test/test_file_utils.py | 121 ++ tests-unit/assets_test/test_list_filter.py | 40 +- .../assets_test/test_prune_orphaned_assets.py | 2 +- .../assets_test/test_sync_references.py | 482 ++++++++ .../{test_tags.py => test_tags_api.py} | 4 +- tests-unit/assets_test/test_uploads.py | 22 +- tests-unit/requirements.txt | 1 - tests-unit/seeder_test/test_seeder.py | 900 ++++++++++++++ utils/mime_types.py | 37 + 62 files changed, 10737 insertions(+), 2878 deletions(-) create mode 100644 alembic_db/versions/0002_merge_to_asset_references.py create mode 100644 app/assets/api/upload.py delete mode 100644 app/assets/database/bulk_ops.py delete mode 100644 app/assets/database/queries.py create mode 100644 app/assets/database/queries/__init__.py create mode 100644 app/assets/database/queries/asset.py create mode 100644 app/assets/database/queries/asset_reference.py create mode 100644 app/assets/database/queries/common.py create mode 100644 app/assets/database/queries/tags.py delete mode 100644 app/assets/database/tags.py delete mode 100644 app/assets/hashing.py delete mode 100644 app/assets/manager.py create mode 100644 app/assets/seeder.py create mode 100644 app/assets/services/__init__.py create mode 100644 app/assets/services/asset_management.py create mode 100644 app/assets/services/bulk_ingest.py create mode 100644 app/assets/services/file_utils.py create mode 100644 app/assets/services/hashing.py create mode 100644 app/assets/services/ingest.py create mode 100644 app/assets/services/metadata_extract.py create mode 100644 app/assets/services/path_utils.py create mode 100644 app/assets/services/schemas.py create mode 100644 app/assets/services/tagging.py create mode 100644 tests-unit/assets_test/helpers.py create mode 100644 tests-unit/assets_test/queries/conftest.py create mode 100644 tests-unit/assets_test/queries/test_asset.py create mode 100644 tests-unit/assets_test/queries/test_asset_info.py create mode 100644 tests-unit/assets_test/queries/test_cache_state.py create mode 100644 tests-unit/assets_test/queries/test_metadata.py create mode 100644 tests-unit/assets_test/queries/test_tags.py create mode 100644 tests-unit/assets_test/services/__init__.py create mode 100644 tests-unit/assets_test/services/conftest.py create mode 100644 tests-unit/assets_test/services/test_asset_management.py create mode 100644 tests-unit/assets_test/services/test_bulk_ingest.py create mode 100644 tests-unit/assets_test/services/test_enrich.py create mode 100644 tests-unit/assets_test/services/test_ingest.py create mode 100644 tests-unit/assets_test/services/test_tagging.py create mode 100644 tests-unit/assets_test/test_file_utils.py create mode 100644 tests-unit/assets_test/test_sync_references.py rename tests-unit/assets_test/{test_tags.py => test_tags_api.py} (98%) create mode 100644 tests-unit/seeder_test/test_seeder.py create mode 100644 utils/mime_types.py diff --git a/alembic_db/versions/0002_merge_to_asset_references.py b/alembic_db/versions/0002_merge_to_asset_references.py new file mode 100644 index 000000000..1ac1b980c --- /dev/null +++ b/alembic_db/versions/0002_merge_to_asset_references.py @@ -0,0 +1,267 @@ +""" +Merge AssetInfo and AssetCacheState into unified asset_references table. + +This migration drops old tables and creates the new unified schema. +All existing data is discarded. + +Revision ID: 0002_merge_to_asset_references +Revises: 0001_assets +Create Date: 2025-02-11 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0002_merge_to_asset_references" +down_revision = "0001_assets" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Drop old tables (order matters due to FK constraints) + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + # Truncate assets table (cascades handled by dropping dependent tables first) + op.execute("DELETE FROM assets") + + # Create asset_references table + op.create_table( + "asset_references", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "asset_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("file_path", sa.Text(), nullable=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column( + "needs_verify", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column( + "preview_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True), + sa.CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + sa.CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), + ) + op.create_index( + "uq_asset_references_file_path", "asset_references", ["file_path"], unique=True + ) + op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"]) + op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"]) + op.create_index("ix_asset_references_name", "asset_references", ["name"]) + op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"]) + op.create_index( + "ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"] + ) + op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"]) + op.create_index( + "ix_asset_references_last_access_time", "asset_references", ["last_access_time"] + ) + op.create_index( + "ix_asset_references_owner_name", "asset_references", ["owner_id", "name"] + ) + op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"]) + + # Create asset_reference_tags table + op.create_table( + "asset_reference_tags", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "tag_name", + sa.String(length=512), + sa.ForeignKey("tags.name", ondelete="RESTRICT"), + nullable=False, + ), + sa.Column( + "origin", sa.String(length=32), nullable=False, server_default="manual" + ), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint( + "asset_reference_id", "tag_name", name="pk_asset_reference_tags" + ), + ) + op.create_index( + "ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"] + ) + op.create_index( + "ix_asset_reference_tags_asset_reference_id", + "asset_reference_tags", + ["asset_reference_id"], + ) + + # Create asset_reference_meta table + op.create_table( + "asset_reference_meta", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint( + "asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta" + ), + ) + op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"]) + op.create_index( + "ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_bool", + "asset_reference_meta", + ["key", "val_bool"], + ) + + +def downgrade() -> None: + """Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema. + + NOTE: Data is not recoverable. The upgrade discards all rows from the old + tables and truncates assets. After downgrade the old schema will be empty. + A filesystem rescan will repopulate data once the older code is running. + """ + # Drop new tables (order matters due to FK constraints) + op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta") + op.drop_table("asset_reference_meta") + + op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags") + op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags") + op.drop_table("asset_reference_tags") + + op.drop_index("ix_asset_references_deleted_at", table_name="asset_references") + op.drop_index("ix_asset_references_owner_name", table_name="asset_references") + op.drop_index("ix_asset_references_last_access_time", table_name="asset_references") + op.drop_index("ix_asset_references_created_at", table_name="asset_references") + op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references") + op.drop_index("ix_asset_references_is_missing", table_name="asset_references") + op.drop_index("ix_asset_references_name", table_name="asset_references") + op.drop_index("ix_asset_references_owner_id", table_name="asset_references") + op.drop_index("ix_asset_references_asset_id", table_name="asset_references") + op.drop_index("uq_asset_references_file_path", table_name="asset_references") + op.drop_table("asset_references") + + # Truncate assets (upgrade deleted all rows; downgrade starts fresh too) + op.execute("DELETE FROM assets") + + # Recreate old tables from 0001_assets schema + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 7676e50b4..40dee9f46 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,56 +1,144 @@ +import asyncio +import functools +import json import logging -import uuid -import urllib.parse import os -import contextlib -from aiohttp import web +import urllib.parse +import uuid +from typing import Any +from aiohttp import web from pydantic import ValidationError -import app.assets.manager as manager -from app import user_manager -from app.assets.api import schemas_in -from app.assets.helpers import get_query_dict -from app.assets.scanner import seed_assets - import folder_paths +from app import user_manager +from app.assets.api import schemas_in, schemas_out +from app.assets.api.schemas_in import ( + AssetValidationError, + UploadError, +) +from app.assets.helpers import validate_blake3_hash +from app.assets.api.upload import ( + delete_temp_file_if_exists, + parse_multipart_upload, +) +from app.assets.seeder import ScanInProgressError, asset_seeder +from app.assets.services import ( + DependencyMissingError, + HashMismatchError, + apply_tags, + asset_exists, + create_from_hash, + delete_asset_reference, + get_asset_detail, + list_assets_page, + list_tags, + remove_tags, + resolve_asset_for_download, + update_asset_metadata, + upload_from_temp_path, +) ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None +_ASSETS_ENABLED = False + + +def _require_assets_feature_enabled(handler): + @functools.wraps(handler) + async def wrapper(request: web.Request) -> web.Response: + if not _ASSETS_ENABLED: + return _build_error_response( + 503, + "SERVICE_DISABLED", + "Assets system is disabled. Start the server with --enable-assets to use this feature.", + ) + return await handler(request) + + return wrapper + # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" -# Note to any custom node developers reading this code: -# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. -def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: - global USER_MANAGER - USER_MANAGER = user_manager_instance +def get_query_dict(request: web.Request) -> dict[str, Any]: + """Gets a dictionary of query parameters from the request. + + request.query is a MultiMapping[str], needs to be converted to a dict + to be validated by Pydantic. + """ + query_dict = { + key: request.query.getall(key) + if len(request.query.getall(key)) > 1 + else request.query.get(key) + for key in request.query.keys() + } + return query_dict + + +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, +# do not rely on the code in /app/assets remaining the same. + + +def register_assets_routes( + app: web.Application, + user_manager_instance: user_manager.UserManager | None = None, +) -> None: + global USER_MANAGER, _ASSETS_ENABLED + if user_manager_instance is not None: + USER_MANAGER = user_manager_instance + _ASSETS_ENABLED = True app.add_routes(ROUTES) -def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: - return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + +def disable_assets_routes() -> None: + """Disable asset routes at runtime (e.g. after DB init failure).""" + global _ASSETS_ENABLED + _ASSETS_ENABLED = False -def _validation_error_response(code: str, ve: ValidationError) -> web.Response: - return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) +def _build_error_response( + status: int, code: str, message: str, details: dict | None = None +) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response: + errors = json.loads(ve.json()) + return _build_error_response(400, code, "Validation failed.", {"errors": errors}) + + +def _validate_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" @ROUTES.head("/api/assets/hash/{hash}") +@_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() - if not hash_str or ":" not in hash_str: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = hash_str.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - exists = manager.asset_exists(asset_hash=hash_str) + try: + hash_str = validate_blake3_hash(hash_str) + except ValueError: + return _build_error_response( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + exists = asset_exists(hash_str) return web.Response(status=200 if exists else 404) @ROUTES.get("/api/assets") -async def list_assets(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def list_assets_route(request: web.Request) -> web.Response: """ GET request to list assets. """ @@ -58,78 +146,140 @@ async def list_assets(request: web.Request) -> web.Response: try: q = schemas_in.ListAssetsQuery.model_validate(query_dict) except ValidationError as ve: - return _validation_error_response("INVALID_QUERY", ve) + return _build_validation_error_response("INVALID_QUERY", ve) - payload = manager.list_assets( + sort = _validate_sort_field(q.sort) + order_candidate = (q.order or "desc").lower() + order = order_candidate if order_candidate in {"asc", "desc"} else "desc" + + result = list_assets_page( + owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, offset=q.offset, - sort=q.sort, - order=q.order, - owner_id=USER_MANAGER.get_request_user_id(request), + sort=sort, + order=order, + ) + + summaries = [ + schemas_out.AssetSummary( + id=item.ref.id, + name=item.ref.name, + asset_hash=item.asset.hash if item.asset else None, + size=int(item.asset.size_bytes) if item.asset else None, + mime_type=item.asset.mime_type if item.asset else None, + tags=item.tags, + created_at=item.ref.created_at, + updated_at=item.ref.updated_at, + last_access_time=item.ref.last_access_time, + ) + for item in result.items + ] + + payload = schemas_out.AssetsList( + assets=summaries, + total=result.total, + has_more=(q.offset + len(summaries)) < result.total, ) return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") -async def get_asset(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def get_asset_route(request: web.Request) -> web.Response: """ GET request to get an asset's info as JSON. """ - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - result = manager.get_asset( - asset_info_id=asset_info_id, + result = get_asset_detail( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), ) + if not result: + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + f"AssetReference {reference_id} not found", + {"id": reference_id}, + ) + + payload = schemas_out.AssetDetail( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else None, + mime_type=result.asset.mime_type if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + ) except ValueError as e: - return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} + ) except Exception: logging.exception( - "get_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "get_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +@_require_assets_feature_enabled async def download_asset_content(request: web.Request) -> web.Response: - # question: do we need disposition? could we just stick with one of these? disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: disposition = "attachment" try: - abs_path, content_type, filename = manager.resolve_asset_content_for_download( - asset_info_id=str(uuid.UUID(request.match_info["id"])), + result = resolve_asset_for_download( + reference_id=str(uuid.UUID(request.match_info["id"])), owner_id=USER_MANAGER.get_request_user_id(request), ) + abs_path = result.abs_path + content_type = result.content_type + filename = result.download_name except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + return _build_error_response(404, "ASSET_NOT_FOUND", str(ve)) except NotImplementedError as nie: - return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie)) except FileNotFoundError: - return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + return _build_error_response( + 404, "FILE_NOT_FOUND", "Underlying file not found on disk." + ) - quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") - cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + _DANGEROUS_MIME_TYPES = { + "text/html", "text/html-sandboxed", "application/xhtml+xml", + "text/javascript", "text/css", + } + if content_type in _DANGEROUS_MIME_TYPES: + content_type = "application/octet-stream" + + safe_name = (filename or "").replace("\r", "").replace("\n", "") + encoded = urllib.parse.quote(safe_name) + cd = f"{disposition}; filename*=UTF-8''{encoded}" file_size = os.path.getsize(abs_path) + size_mb = file_size / (1024 * 1024) logging.info( - "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + "download_asset_content: path=%s, size=%d bytes (%.2f MB), type=%s, name=%s", abs_path, file_size, - file_size / (1024 * 1024), + size_mb, content_type, filename, ) - async def file_sender(): + async def stream_file_chunks(): chunk_size = 64 * 1024 with open(abs_path, "rb") as f: while True: @@ -139,26 +289,30 @@ async def download_asset_content(request: web.Request) -> web.Response: yield chunk return web.Response( - body=file_sender(), + body=stream_file_chunks(), content_type=content_type, headers={ "Content-Disposition": cd, "Content-Length": str(file_size), + "X-Content-Type-Options": "nosniff", }, ) @ROUTES.post("/api/assets/from-hash") -async def create_asset_from_hash(request: web.Request) -> web.Response: +@_require_assets_feature_enabled +async def create_asset_from_hash_route(request: web.Request) -> web.Response: try: payload = await request.json() body = schemas_in.CreateFromHashBody.model_validate(payload) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) - result = manager.create_asset_from_hash( + result = create_from_hash( hash_str=body.hash, name=body.name, tags=body.tags, @@ -166,246 +320,209 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") - return web.json_response(result.model_dump(mode="json"), status=201) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" + ) + + payload_out = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + return web.json_response(payload_out.model_dump(mode="json"), status=201) @ROUTES.post("/api/assets") +@_require_assets_feature_enabled async def upload_asset(request: web.Request) -> web.Response: """Multipart/form-data endpoint for Asset uploads.""" - if not (request.content_type or "").lower().startswith("multipart/"): - return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") - - reader = await request.multipart() - - file_present = False - file_client_name: str | None = None - tags_raw: list[str] = [] - provided_name: str | None = None - user_metadata_raw: str | None = None - provided_hash: str | None = None - provided_hash_exists: bool | None = None - - file_written = 0 - tmp_path: str | None = None - while True: - field = await reader.next() - if field is None: - break - - fname = getattr(field, "name", "") or "" - - if fname == "hash": - try: - s = ((await field.text()) or "").strip().lower() - except Exception: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - - if s: - if ":" not in s: - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") - provided_hash = f"{algo}:{digest}" - try: - provided_hash_exists = manager.asset_exists(asset_hash=provided_hash) - except Exception: - provided_hash_exists = None # do not fail the whole request here - - elif fname == "file": - file_present = True - file_client_name = (field.filename or "").strip() - - if provided_hash and provided_hash_exists is True: - # If client supplied a hash that we know exists, drain but do not write to disk - try: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - file_written += len(chunk) - except Exception: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") - continue # Do not create temp file; we will create AssetInfo from the existing content - - # Otherwise, store to temp for hashing/ingest - uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") - unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) - os.makedirs(unique_dir, exist_ok=True) - tmp_path = os.path.join(unique_dir, ".upload.part") - - try: - with open(tmp_path, "wb") as f: - while True: - chunk = await field.read_chunk(8 * 1024 * 1024) - if not chunk: - break - f.write(chunk) - file_written += len(chunk) - except Exception: - try: - if os.path.exists(tmp_path or ""): - os.remove(tmp_path) - finally: - return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") - elif fname == "tags": - tags_raw.append((await field.text()) or "") - elif fname == "name": - provided_name = (await field.text()) or None - elif fname == "user_metadata": - user_metadata_raw = (await field.text()) or None - - # If client did not send file, and we are not doing a from-hash fast path -> error - if not file_present and not (provided_hash and provided_hash_exists): - return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") - - if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): - # Empty upload is only acceptable if we are fast-pathing from existing hash - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") - try: - spec = schemas_in.UploadAssetSpec.model_validate({ - "tags": tags_raw, - "name": provided_name, - "user_metadata": user_metadata_raw, - "hash": provided_hash, - }) - except ValidationError as ve: - try: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _validation_error_response("INVALID_BODY", ve) - - # Validate models category against configured folders (consistent with previous behavior) - if spec.tags and spec.tags[0] == "models": - if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - return _error_response( - 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" - ) + parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists) + except UploadError as e: + return _build_error_response(e.status, e.code, e.message) owner_id = USER_MANAGER.get_request_user_id(request) - # Fast path: if a valid provided hash exists, create AssetInfo without writing anything - if spec.hash and provided_hash_exists is True: - try: - result = manager.create_asset_from_hash( + try: + spec = schemas_in.UploadAssetSpec.model_validate( + { + "tags": parsed.tags_raw, + "name": parsed.provided_name, + "user_metadata": parsed.user_metadata_raw, + "hash": parsed.provided_hash, + } + ) + except ValidationError as ve: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 400, "INVALID_BODY", f"Validation failed: {ve.json()}" + ) + + if spec.tags and spec.tags[0] == "models": + if ( + len(spec.tags) < 2 + or spec.tags[1] not in folder_paths.folder_names_and_paths + ): + delete_temp_file_if_exists(parsed.tmp_path) + category = spec.tags[1] if len(spec.tags) >= 2 else "" + return _build_error_response( + 400, "INVALID_BODY", f"unknown models category '{category}'" + ) + + try: + # Fast path: hash exists, create AssetReference without writing anything + if spec.hash and parsed.provided_hash_exists is True: + result = create_from_hash( hash_str=spec.hash, name=spec.name or (spec.hash.split(":", 1)[1]), tags=spec.tags, user_metadata=spec.user_metadata or {}, owner_id=owner_id, ) - except Exception: - logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + if result is None: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist" + ) + delete_temp_file_if_exists(parsed.tmp_path) + else: + # Otherwise, we must have a temp file path to ingest + if not parsed.tmp_path or not os.path.exists(parsed.tmp_path): + return _build_error_response( + 400, + "MISSING_INPUT", + "Provided hash not found and no file uploaded.", + ) - if result is None: - return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") - - # Drain temp if we accidentally saved (e.g., hash field came after file) - if tmp_path and os.path.exists(tmp_path): - with contextlib.suppress(Exception): - os.remove(tmp_path) - - status = 200 if (not result.created_new) else 201 - return web.json_response(result.model_dump(mode="json"), status=status) - - # Otherwise, we must have a temp file path to ingest - if not tmp_path or not os.path.exists(tmp_path): - # The only case we reach here without a temp file is: client sent a hash that does not exist and no file - return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") - - try: - created = manager.upload_asset_from_temp_path( - spec, - temp_path=tmp_path, - client_filename=file_client_name, - owner_id=owner_id, - expected_asset_hash=spec.hash, - ) - status = 201 if created.created_new else 200 - return web.json_response(created.model_dump(mode="json"), status=status) - except ValueError as e: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - msg = str(e) - if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": - return _error_response( - 400, - "HASH_MISMATCH", - "Uploaded file hash does not match provided hash.", + result = upload_from_temp_path( + temp_path=parsed.tmp_path, + name=spec.name, + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + client_filename=parsed.file_client_name, + owner_id=owner_id, + expected_hash=spec.hash, ) - return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except AssetValidationError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, e.code, str(e)) + except ValueError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "BAD_REQUEST", str(e)) + except HashMismatchError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "HASH_MISMATCH", str(e)) + except DependencyMissingError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(503, "DEPENDENCY_MISSING", e.message) except Exception: - if tmp_path and os.path.exists(tmp_path): - os.remove(tmp_path) - logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) - return _error_response(500, "INTERNAL", "Unexpected server error.") + delete_temp_file_if_exists(parsed.tmp_path) + logging.exception("upload_asset failed for owner_id=%s", owner_id) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + payload = schemas_out.AssetCreated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash, + size=int(result.asset.size_bytes), + mime_type=result.asset.mime_type, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + preview_id=result.ref.preview_id, + created_at=result.ref.created_at, + last_access_time=result.ref.last_access_time, + created_new=result.created_new, + ) + status = 201 if result.created_new else 200 + return web.json_response(payload.model_dump(mode="json"), status=status) @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") -async def update_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) +@_require_assets_feature_enabled +async def update_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) try: body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: - return _validation_error_response("INVALID_BODY", ve) + return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.update_asset( - asset_info_id=asset_info_id, + result = update_asset_metadata( + reference_id=reference_id, name=body.name, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.AssetUpdated( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + tags=result.tags, + user_metadata=result.ref.user_metadata or {}, + updated_at=result.ref.updated_at, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "update_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "update_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") -async def delete_asset(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) - delete_content = request.query.get("delete_content") - delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} +@_require_assets_feature_enabled +async def delete_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + delete_content_param = request.query.get("delete_content") + delete_content = ( + False + if delete_content_param is None + else delete_content_param.lower() not in {"0", "false", "no"} + ) try: - deleted = manager.delete_asset_reference( - asset_info_id=asset_info_id, + deleted = delete_asset_reference( + reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), delete_content_if_orphan=delete_content, ) except Exception: logging.exception( - "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "delete_asset_reference failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") if not deleted: - return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"AssetReference {reference_id} not found." + ) return web.Response(status=204) @ROUTES.get("/api/tags") +@_require_assets_feature_enabled async def get_tags(request: web.Request) -> web.Response: """ GET request to list all tags based on query parameters. @@ -415,12 +532,14 @@ async def get_tags(request: web.Request) -> web.Response: try: query = schemas_in.TagsListQuery.model_validate(query_map) except ValidationError as e: - return web.json_response( - {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, - status=400, + return _build_error_response( + 400, + "INVALID_QUERY", + "Invalid query parameters", + {"errors": json.loads(e.json())}, ) - result = manager.list_tags( + rows, total = list_tags( prefix=query.prefix, limit=query.limit, offset=query.offset, @@ -428,87 +547,212 @@ async def get_tags(request: web.Request) -> web.Response: include_zero=query.include_zero, owner_id=USER_MANAGER.get_request_user_id(request), ) - return web.json_response(result.model_dump(mode="json")) + + tags = [ + schemas_out.TagUsage(name=name, count=count, type=tag_type) + for (name, tag_type, count) in rows + ] + payload = schemas_out.TagsList( + tags=tags, total=total, has_more=(query.offset + len(tags)) < total + ) + return web.json_response(payload.model_dump(mode="json")) @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def add_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsAdd.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsAdd.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags add.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.add_tags_to_asset( - asset_info_id=asset_info_id, + result = apply_tags( + reference_id=reference_id, tags=data.tags, origin="manual", owner_id=USER_MANAGER.get_request_user_id(request), ) - except (ValueError, PermissionError) as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + payload = schemas_out.TagsAdd( + added=result.added, + already_present=result.already_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "add_tags_to_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def delete_asset_tags(request: web.Request) -> web.Response: - asset_info_id = str(uuid.UUID(request.match_info["id"])) + reference_id = str(uuid.UUID(request.match_info["id"])) try: - payload = await request.json() - data = schemas_in.TagsRemove.model_validate(payload) + json_payload = await request.json() + data = schemas_in.TagsRemove.model_validate(json_payload) except ValidationError as ve: - return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags remove.", + {"errors": ve.errors()}, + ) except Exception: - return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: - result = manager.remove_tags_from_asset( - asset_info_id=asset_info_id, + result = remove_tags( + reference_id=reference_id, tags=data.tags, owner_id=USER_MANAGER.get_request_user_id(request), ) + payload = schemas_out.TagsRemove( + removed=result.removed, + not_present=result.not_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) except ValueError as ve: - return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) except Exception: logging.exception( - "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", - asset_info_id, + "remove_tags_from_asset failed for reference_id=%s, owner_id=%s", + reference_id, USER_MANAGER.get_request_user_id(request), ) - return _error_response(500, "INTERNAL", "Unexpected server error.") + return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json"), status=200) @ROUTES.post("/api/assets/seed") -async def seed_assets_endpoint(request: web.Request) -> web.Response: - """Trigger asset seeding for specified roots (models, input, output).""" +@_require_assets_feature_enabled +async def seed_assets(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output). + + Query params: + wait: If "true", block until scan completes (synchronous behavior for tests) + + Returns: + 202 Accepted if scan started + 409 Conflict if scan already running + 200 OK with final stats if wait=true + """ try: payload = await request.json() roots = payload.get("roots", ["models", "input", "output"]) except Exception: roots = ["models", "input", "output"] - valid_roots = [r for r in roots if r in ("models", "input", "output")] + valid_roots = tuple(r for r in roots if r in ("models", "input", "output")) if not valid_roots: - return _error_response(400, "INVALID_BODY", "No valid roots specified") + return _build_error_response(400, "INVALID_BODY", "No valid roots specified") + wait_param = request.query.get("wait", "").lower() + should_wait = wait_param in ("true", "1", "yes") + + started = asset_seeder.start(roots=valid_roots) + if not started: + return web.json_response({"status": "already_running"}, status=409) + + if should_wait: + await asyncio.to_thread(asset_seeder.wait) + status = asset_seeder.get_status() + return web.json_response( + { + "status": "completed", + "progress": { + "scanned": status.progress.scanned if status.progress else 0, + "total": status.progress.total if status.progress else 0, + "created": status.progress.created if status.progress else 0, + "skipped": status.progress.skipped if status.progress else 0, + }, + "errors": status.errors, + }, + status=200, + ) + + return web.json_response({"status": "started"}, status=202) + + +@ROUTES.get("/api/assets/seed/status") +@_require_assets_feature_enabled +async def get_seed_status(request: web.Request) -> web.Response: + """Get current scan status and progress.""" + status = asset_seeder.get_status() + return web.json_response( + { + "state": status.state.value, + "progress": { + "scanned": status.progress.scanned, + "total": status.progress.total, + "created": status.progress.created, + "skipped": status.progress.skipped, + } + if status.progress + else None, + "errors": status.errors, + }, + status=200, + ) + + +@ROUTES.post("/api/assets/seed/cancel") +@_require_assets_feature_enabled +async def cancel_seed(request: web.Request) -> web.Response: + """Request cancellation of in-progress scan.""" + cancelled = asset_seeder.cancel() + if cancelled: + return web.json_response({"status": "cancelling"}, status=200) + return web.json_response({"status": "idle"}, status=200) + + +@ROUTES.post("/api/assets/prune") +@_require_assets_feature_enabled +async def mark_missing_assets(request: web.Request) -> web.Response: + """Mark assets as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and metadata + are preserved, but references are flagged as missing. They can be + restored if the file reappears in a future scan. + + Returns: + 200 OK with count of marked assets + 409 Conflict if a scan is currently running + """ try: - seed_assets(tuple(valid_roots)) - except Exception: - logging.exception("seed_assets failed for roots=%s", valid_roots) - return _error_response(500, "INTERNAL", "Seed operation failed") - - return web.json_response({"seeded": valid_roots}, status=200) + marked = asset_seeder.mark_missing_outside_prefixes() + except ScanInProgressError: + return web.json_response( + {"status": "scan_running", "marked": 0}, + status=409, + ) + return web.json_response({"status": "completed", "marked": marked}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 6707ffb0c..d255c938e 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,6 +1,8 @@ import json +from dataclasses import dataclass from typing import Any, Literal +from app.assets.helpers import validate_blake3_hash from pydantic import ( BaseModel, ConfigDict, @@ -10,6 +12,41 @@ from pydantic import ( model_validator, ) + +class UploadError(Exception): + """Error during upload parsing with HTTP status and code.""" + + def __init__(self, status: int, code: str, message: str): + super().__init__(message) + self.status = status + self.code = code + self.message = message + + +class AssetValidationError(Exception): + """Validation error in asset processing (invalid tags, metadata, etc.).""" + + def __init__(self, code: str, message: str): + super().__init__(message) + self.code = code + self.message = message + + +@dataclass +class ParsedUpload: + """Result of parsing a multipart upload request.""" + + file_present: bool + file_written: int + file_client_name: str | None + tmp_path: str | None + tags_raw: list[str] + provided_name: str | None + user_metadata_raw: str | None + provided_hash: str | None + provided_hash_exists: bool | None + + class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 - sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( + "created_at" + ) order: Literal["asc", "desc"] = "desc" @field_validator("include_tags", "exclude_tags", mode="before") @@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel): user_metadata: dict[str, Any] | None = None @model_validator(mode="after") - def _at_least_one(self): + def _validate_at_least_one_field(self): if self.name is None and self.user_metadata is None: raise ValueError("Provide at least one of: name, user_metadata.") return self @@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel): @field_validator("hash") @classmethod def _require_blake3(cls, v): - s = (v or "").strip().lower() - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return s + return validate_blake3_hash(v or "") @field_validator("tags", mode="before") @classmethod - def _tags_norm(cls, v): + def _normalize_tags_field(cls, v): if v is None: return [] if isinstance(v, list): @@ -154,15 +185,16 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); - if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + if root == 'models', second must be a valid category - name: display name - user_metadata: arbitrary JSON object (optional) - - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + - hash: optional canonical 'blake3:' for validation / fast-path - Files created via this endpoint are stored on disk using the **content hash** as the filename stem - and the original extension is preserved when available. + Files are stored using the content hash as filename stem. """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) @@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel): def _parse_hash(cls, v): if v is None: return None - s = str(v).strip().lower() + s = str(v).strip() if not s: return None - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return f"{algo}:{digest}" + return validate_blake3_hash(s) @field_validator("tags", mode="before") @classmethod @@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel): raise ValueError("first tag must be one of: models, input, output") if root == "models": if len(self.tags) < 2: - raise ValueError("models uploads require a category tag as the second tag") + raise ValueError( + "models uploads require a category tag as the second tag" + ) return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index b6fb3da0c..f36447856 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -19,7 +19,7 @@ class AssetSummary(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "updated_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -40,7 +40,7 @@ class AssetUpdated(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("updated_at") - def _ser_updated(self, v: datetime | None, _info): + def _serialize_updated_at(self, v: datetime | None, _info): return v.isoformat() if v else None @@ -59,7 +59,7 @@ class AssetDetail(BaseModel): model_config = ConfigDict(from_attributes=True) @field_serializer("created_at", "last_access_time") - def _ser_dt(self, v: datetime | None, _info): + def _serialize_datetime(self, v: datetime | None, _info): return v.isoformat() if v else None diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py new file mode 100644 index 000000000..721c12f4d --- /dev/null +++ b/app/assets/api/upload.py @@ -0,0 +1,171 @@ +import logging +import os +import uuid +from typing import Callable + +from aiohttp import web + +import folder_paths +from app.assets.api.schemas_in import ParsedUpload, UploadError +from app.assets.helpers import validate_blake3_hash + + +def normalize_and_validate_hash(s: str) -> str: + """Validate and normalize a hash string. + + Returns canonical 'blake3:' or raises UploadError. + """ + try: + return validate_blake3_hash(s) + except ValueError: + raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") + + +async def parse_multipart_upload( + request: web.Request, + check_hash_exists: Callable[[str], bool], +) -> ParsedUpload: + """ + Parse a multipart/form-data upload request. + + Args: + request: The aiohttp request + check_hash_exists: Callable(hash_str) -> bool to check if a hash exists + + Returns: + ParsedUpload with parsed fields and temp file path + + Raises: + UploadError: On validation or I/O errors + """ + if not (request.content_type or "").lower().startswith("multipart/"): + raise UploadError( + 415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads." + ) + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + raise UploadError( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) + + if s: + provided_hash = normalize_and_validate_hash(s) + try: + provided_hash_exists = check_hash_exists(provided_hash) + except Exception as e: + logging.exception( + "check_hash_exists failed for hash=%s: %s", provided_hash, e + ) + raise UploadError( + 500, + "HASH_CHECK_FAILED", + "Backend error while checking asset hash.", + ) + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # Hash exists - drain file but don't write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file." + ) + continue + + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + delete_temp_file_if_exists(tmp_path) + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file." + ) + + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + if not file_present and not (provided_hash and provided_hash_exists): + raise UploadError( + 400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'." + ) + + if ( + file_present + and file_written == 0 + and not (provided_hash and provided_hash_exists) + ): + delete_temp_file_if_exists(tmp_path) + raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + return ParsedUpload( + file_present=file_present, + file_written=file_written, + file_client_name=file_client_name, + tmp_path=tmp_path, + tags_raw=tags_raw, + provided_name=provided_name, + user_metadata_raw=user_metadata_raw, + provided_hash=provided_hash, + provided_hash_exists=provided_hash_exists, + ) + + +def delete_temp_file_if_exists(tmp_path: str | None) -> None: + """Safely remove a temp file and its parent directory if empty.""" + if tmp_path: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except OSError as e: + logging.debug("Failed to delete temp file %s: %s", tmp_path, e) + try: + parent = os.path.dirname(tmp_path) + if parent and os.path.isdir(parent): + os.rmdir(parent) # only succeeds if empty + except OSError: + pass diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py deleted file mode 100644 index c7b75290a..000000000 --- a/app/assets/database/bulk_ops.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -import uuid -import sqlalchemy -from typing import Iterable -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import utcnow -from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta - -MAX_BIND_PARAMS = 800 - -def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: - if not rows: - return [] - rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) - for i in range(0, len(rows), rows_per_stmt): - yield rows[i:i + rows_per_stmt] - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i:i + n] - -def _rows_per_stmt(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def seed_from_paths_batch( - session: Session, - *, - specs: list[dict], - owner_id: str = "", -) -> dict: - """Each spec is a dict with keys: - - abs_path: str - - size_bytes: int - - mtime_ns: int - - info_name: str - - tags: list[str] - - fname: Optional[str] - """ - if not specs: - return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} - - now = utcnow() - asset_rows: list[dict] = [] - state_rows: list[dict] = [] - path_to_asset: dict[str, str] = {} - asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row - path_list: list[str] = [] - - for sp in specs: - ap = os.path.abspath(sp["abs_path"]) - aid = str(uuid.uuid4()) - iid = str(uuid.uuid4()) - path_list.append(ap) - path_to_asset[ap] = aid - - asset_rows.append( - { - "id": aid, - "hash": None, - "size_bytes": sp["size_bytes"], - "mime_type": None, - "created_at": now, - } - ) - state_rows.append( - { - "asset_id": aid, - "file_path": ap, - "mtime_ns": sp["mtime_ns"], - } - ) - asset_to_info[aid] = { - "id": iid, - "owner_id": owner_id, - "name": sp["info_name"], - "asset_id": aid, - "preview_id": None, - "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, - "created_at": now, - "updated_at": now, - "last_access_time": now, - "_tags": sp["tags"], - "_filename": sp["fname"], - } - - # insert all seed Assets (hash=NULL) - ins_asset = sqlite.insert(Asset) - for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): - session.execute(ins_asset, chunk) - - # try to claim AssetCacheState (file_path) - # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted - ins_state = ( - sqlite.insert(AssetCacheState) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): - session.execute(ins_state, chunk) - - # Query to find which of our paths won (were actually inserted) - winners_by_path: set[str] = set() - for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetCacheState.file_path) - .where(AssetCacheState.file_path.in_(chunk)) - .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) - ) - winners_by_path.update(result.scalars().all()) - - all_paths_set = set(path_list) - losers_by_path = all_paths_set - winners_by_path - lost_assets = [path_to_asset[p] for p in losers_by_path] - if lost_assets: # losers get their Asset removed - for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): - session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) - - if not winners_by_path: - return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} - - # insert AssetInfo only for winners - # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted - winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] - ins_info = ( - sqlite.insert(AssetInfo) - .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) - ) - for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): - session.execute(ins_info, chunk) - - # Query to find which info rows were actually inserted (by matching our generated IDs) - all_info_ids = [row["id"] for row in winner_info_rows] - inserted_info_ids: set[str] = set() - for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): - result = session.execute( - sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) - ) - inserted_info_ids.update(result.scalars().all()) - - # build and insert tag + meta rows for the AssetInfo - tag_rows: list[dict] = [] - meta_rows: list[dict] = [] - if inserted_info_ids: - for row in winner_info_rows: - iid = row["id"] - if iid not in inserted_info_ids: - continue - for t in row["_tags"]: - tag_rows.append({ - "asset_info_id": iid, - "tag_name": t, - "origin": "automatic", - "added_at": now, - }) - if row["_filename"]: - meta_rows.append( - { - "asset_info_id": iid, - "key": "filename", - "ordinal": 0, - "val_str": row["_filename"], - "val_num": None, - "val_bool": None, - "val_json": None, - } - ) - - bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) - return { - "inserted_infos": len(inserted_info_ids), - "won_states": len(winners_by_path), - "lost_states": len(losers_by_path), - } - - -def bulk_insert_tags_and_meta( - session: Session, - *, - tag_rows: list[dict], - meta_rows: list[dict], - max_bind_params: int, -) -> None: - """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. - - tag_rows keys: asset_info_id, tag_name, origin, added_at - - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json - """ - if tag_rows: - ins_links = ( - sqlite.insert(AssetInfoTag) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): - session.execute(ins_links, chunk) - if meta_rows: - ins_meta = ( - sqlite.insert(AssetInfoMeta) - .on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] - ) - ) - for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): - session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 3cd28f68b..03c1c1707 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -2,8 +2,8 @@ from __future__ import annotations import uuid from datetime import datetime - from typing import Any + from sqlalchemy import ( JSON, BigInteger, @@ -16,102 +16,102 @@ from sqlalchemy import ( Numeric, String, Text, - UniqueConstraint, ) from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from app.assets.helpers import utcnow -from app.database.models import to_dict, Base +from app.assets.helpers import get_utc_now +from app.database.models import Base class Asset(Base): __tablename__ = "assets" - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) hash: Mapped[str | None] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - infos: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + references: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), - foreign_keys=lambda: [AssetInfo.asset_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id), + foreign_keys=lambda: [AssetReference.asset_id], cascade="all,delete-orphan", passive_deletes=True, ) - preview_of: Mapped[list[AssetInfo]] = relationship( - "AssetInfo", + preview_of: Mapped[list[AssetReference]] = relationship( + "AssetReference", back_populates="preview_asset", - primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), - foreign_keys=lambda: [AssetInfo.preview_id], + primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id), + foreign_keys=lambda: [AssetReference.preview_id], viewonly=True, ) - cache_states: Mapped[list[AssetCacheState]] = relationship( - back_populates="asset", - cascade="all, delete-orphan", - passive_deletes=True, - ) - __table_args__ = ( Index("uq_assets_hash", "hash", unique=True), Index("ix_assets_mime_type", "mime_type"), CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) - def __repr__(self) -> str: return f"" -class AssetCacheState(Base): - __tablename__ = "asset_cache_state" +class AssetReference(Base): + """Unified model combining file cache state and user-facing metadata. - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) - file_path: Mapped[str] = mapped_column(Text, nullable=False) - mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) - needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + Each row represents either: + - A filesystem reference (file_path is set) with cache state + - An API-created reference (file_path is NULL) without cache state + """ - asset: Mapped[Asset] = relationship(back_populates="cache_states") + __tablename__ = "asset_references" - __table_args__ = ( - Index("ix_asset_cache_state_file_path", "file_path"), - Index("ix_asset_cache_state_asset_id", "asset_id"), - CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), - UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) + asset_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - return to_dict(self, include_none=include_none) + # Cache state fields (from former AssetCacheState) + file_path: Mapped[str | None] = mapped_column(Text, nullable=True) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - def __repr__(self) -> str: - return f"" - - -class AssetInfo(Base): - __tablename__ = "assets_info" - - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + # Info fields (from former AssetInfo) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) - preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + preview_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="SET NULL") + ) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True) + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + deleted_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=False), nullable=True, default=None + ) asset: Mapped[Asset] = relationship( "Asset", - back_populates="infos", + back_populates="references", foreign_keys=[asset_id], lazy="selectin", ) @@ -121,51 +121,59 @@ class AssetInfo(Base): foreign_keys=[preview_id], ) - metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( - back_populates="asset_info", + metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, ) - tag_links: Mapped[list[AssetInfoTag]] = relationship( - back_populates="asset_info", + tag_links: Mapped[list[AssetReferenceTag]] = relationship( + back_populates="asset_reference", cascade="all,delete-orphan", passive_deletes=True, - overlaps="tags,asset_infos", + overlaps="tags,asset_references", ) tags: Mapped[list[Tag]] = relationship( - secondary="asset_info_tags", - back_populates="asset_infos", + secondary="asset_reference_tags", + back_populates="asset_references", lazy="selectin", viewonly=True, - overlaps="tag_links,asset_info_links,asset_infos,tag", + overlaps="tag_links,asset_reference_links,asset_references,tag", ) __table_args__ = ( - UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), - Index("ix_assets_info_owner_name", "owner_id", "name"), - Index("ix_assets_info_owner_id", "owner_id"), - Index("ix_assets_info_asset_id", "asset_id"), - Index("ix_assets_info_name", "name"), - Index("ix_assets_info_created_at", "created_at"), - Index("ix_assets_info_last_access_time", "last_access_time"), + Index("uq_asset_references_file_path", "file_path", unique=True), + Index("ix_asset_references_asset_id", "asset_id"), + Index("ix_asset_references_owner_id", "owner_id"), + Index("ix_asset_references_name", "name"), + Index("ix_asset_references_is_missing", "is_missing"), + Index("ix_asset_references_enrichment_level", "enrichment_level"), + Index("ix_asset_references_created_at", "created_at"), + Index("ix_asset_references_last_access_time", "last_access_time"), + Index("ix_asset_references_deleted_at", "deleted_at"), + Index("ix_asset_references_owner_name", "owner_id", "name"), + CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), ) - def to_dict(self, include_none: bool = False) -> dict[str, Any]: - data = to_dict(self, include_none=include_none) - data["tags"] = [t.name for t in self.tags] - return data - def __repr__(self) -> str: - return f"" + path_part = f" path={self.file_path!r}" if self.file_path else "" + return f"" -class AssetInfoMeta(Base): - __tablename__ = "asset_info_meta" +class AssetReferenceMeta(Base): + __tablename__ = "asset_reference_meta" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) key: Mapped[str] = mapped_column(String(256), primary_key=True) ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) @@ -175,36 +183,40 @@ class AssetInfoMeta(Base): val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) - asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + asset_reference: Mapped[AssetReference] = relationship( + back_populates="metadata_entries" + ) __table_args__ = ( - Index("ix_asset_info_meta_key", "key"), - Index("ix_asset_info_meta_key_val_str", "key", "val_str"), - Index("ix_asset_info_meta_key_val_num", "key", "val_num"), - Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + Index("ix_asset_reference_meta_key", "key"), + Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), + Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), + Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), ) -class AssetInfoTag(Base): - __tablename__ = "asset_info_tags" +class AssetReferenceTag(Base): + __tablename__ = "asset_reference_tags" - asset_info_id: Mapped[str] = mapped_column( - String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + asset_reference_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("asset_references.id", ondelete="CASCADE"), + primary_key=True, ) tag_name: Mapped[str] = mapped_column( String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True ) origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") added_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow + DateTime(timezone=False), nullable=False, default=get_utc_now ) - asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") - tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_reference_links") __table_args__ = ( - Index("ix_asset_info_tags_tag_name", "tag_name"), - Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + Index("ix_asset_reference_tags_tag_name", "tag_name"), + Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"), ) @@ -214,20 +226,18 @@ class Tag(Base): name: Mapped[str] = mapped_column(String(512), primary_key=True) tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") - asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship( back_populates="tag", - overlaps="asset_infos,tags", + overlaps="asset_references,tags", ) - asset_infos: Mapped[list[AssetInfo]] = relationship( - secondary="asset_info_tags", + asset_references: Mapped[list[AssetReference]] = relationship( + secondary="asset_reference_tags", back_populates="tags", viewonly=True, - overlaps="asset_info_links,tag_links,tags,asset_info", + overlaps="asset_reference_links,tag_links,tags,asset_reference", ) - __table_args__ = ( - Index("ix_tags_tag_type", "tag_type"), - ) + __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py deleted file mode 100644 index d6b33ec7b..000000000 --- a/app/assets/database/queries.py +++ /dev/null @@ -1,976 +0,0 @@ -import os -import logging -import sqlalchemy as sa -from collections import defaultdict -from datetime import datetime -from typing import Iterable, Any -from sqlalchemy import select, delete, exists, func -from sqlalchemy.dialects import sqlite -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag -from app.assets.helpers import ( - compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow -) -from typing import Sequence - - -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - -def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: - """ - Return the best on-disk path among cache states: - 1) Prefer a path that exists with needs_verify == False (already verified). - 2) Otherwise, pick the first path that exists. - 3) Otherwise return empty string. - """ - alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] - if not alive: - return "" - for s in alive: - if not getattr(s, "needs_verify", False): - return s.file_path - return alive[0].file_path - - -def apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: dict | None = None, -) -> sa.sql.Select: - """Apply filters using asset_info_meta projection table.""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - if value is None: - no_row_for_key = sa.not_( - sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetInfoMeta.val_json.is_(None), - AssetInfoMeta.val_str.is_(None), - AssetInfoMeta.val_num.is_(None), - AssetInfoMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float)): - from decimal import Decimal - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def asset_exists_by_hash( - session: Session, - *, - asset_hash: str, -) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - row = ( - session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - -def asset_info_exists_for_asset_id( - session: Session, - *, - asset_id: str, -) -> bool: - q = ( - select(sa.literal(True)) - .select_from(AssetInfo) - .where(AssetInfo.asset_id == asset_id) - .limit(1) - ) - return (session.execute(q)).first() is not None - - -def get_asset_by_hash( - session: Session, - *, - asset_hash: str, -) -> Asset | None: - return ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - - -def get_asset_info_by_id( - session: Session, - *, - asset_info_id: str, -) -> AssetInfo | None: - return session.get(AssetInfo, asset_info_id) - - -def list_asset_infos_page( - session: Session, - owner_id: str = "", - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", -) -> tuple[list[AssetInfo], dict[str, list[str]], int]: - base = ( - select(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) - .where(visible_owner_clause(owner_id)) - ) - - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - - base = apply_tag_filters(base, include_tags, exclude_tags) - base = apply_metadata_filter(base, metadata_filter) - - sort = (sort or "created_at").lower() - order = (order or "desc").lower() - sort_map = { - "name": AssetInfo.name, - "created_at": AssetInfo.created_at, - "updated_at": AssetInfo.updated_at, - "last_access_time": AssetInfo.last_access_time, - "size": Asset.size_bytes, - } - sort_col = sort_map.get(sort, AssetInfo.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() - - base = base.order_by(sort_exp).limit(limit).offset(offset) - - count_stmt = ( - select(sa.func.count()) - .select_from(AssetInfo) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where(visible_owner_clause(owner_id)) - ) - if name_contains: - escaped, esc = escape_like_prefix(name_contains) - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = apply_metadata_filter(count_stmt, metadata_filter) - - total = int((session.execute(count_stmt)).scalar_one() or 0) - - infos = (session.execute(base)).unique().scalars().all() - - id_list: list[str] = [i.id for i in infos] - tag_map: dict[str, list[str]] = defaultdict(list) - if id_list: - rows = session.execute( - select(AssetInfoTag.asset_info_id, Tag.name) - .join(Tag, Tag.name == AssetInfoTag.tag_name) - .where(AssetInfoTag.asset_info_id.in_(id_list)) - .order_by(AssetInfoTag.added_at) - ) - for aid, tag_name in rows.all(): - tag_map[aid].append(tag_name) - - return infos, tag_map, total - - -def fetch_asset_info_asset_and_tags( - session: Session, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset, list[str]] | None: - stmt = ( - select(AssetInfo, Asset, Tag.name) - .join(Asset, Asset.id == AssetInfo.asset_id) - .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) - .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .options(noload(AssetInfo.tags)) - .order_by(Tag.name.asc()) - ) - - rows = (session.execute(stmt)).all() - if not rows: - return None - - first_info, first_asset, _ = rows[0] - tags: list[str] = [] - seen: set[str] = set() - for _info, _asset, tag_name in rows: - if tag_name and tag_name not in seen: - seen.add(tag_name) - tags.append(tag_name) - return first_info, first_asset, tags - - -def fetch_asset_info_and_asset( - session: Session, - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[AssetInfo, Asset] | None: - stmt = ( - select(AssetInfo, Asset) - .join(Asset, Asset.id == AssetInfo.asset_id) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .limit(1) - .options(noload(AssetInfo.tags)) - ) - row = session.execute(stmt) - pair = row.first() - if not pair: - return None - return pair[0], pair[1] - -def list_cache_states_by_asset_id( - session: Session, *, asset_id: str -) -> Sequence[AssetCacheState]: - return ( - session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_id == asset_id) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - - -def touch_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - ts: datetime | None = None, - only_if_newer: bool = True, -) -> None: - ts = ts or utcnow() - stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) - if only_if_newer: - stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) - ) - session.execute(stmt.values(last_access_time=ts)) - - -def create_asset_info_for_existing_asset( - session: Session, - *, - asset_hash: str, - name: str, - user_metadata: dict | None = None, - tags: Sequence[str] | None = None, - tag_origin: str = "manual", - owner_id: str = "", -) -> AssetInfo: - """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" - now = utcnow() - asset = get_asset_by_hash(session, asset_hash=asset_hash) - if not asset: - raise ValueError(f"Unknown asset hash {asset_hash}") - - info = AssetInfo( - owner_id=owner_id, - name=name, - asset_id=asset.id, - preview_id=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - try: - with session.begin_nested(): - session.add(info) - session.flush() - except IntegrityError: - existing = ( - session.execute( - select(AssetInfo) - .options(noload(AssetInfo.tags)) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, - ) - .limit(1) - ) - ).unique().scalars().first() - if not existing: - raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") - return existing - - # metadata["filename"] hack - new_meta = dict(user_metadata or {}) - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - if computed_filename: - new_meta["filename"] = computed_filename - if new_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata=new_meta, - ) - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=info.id, - tags=tags, - origin=tag_origin, - ) - return info - - -def set_asset_info_tags( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - desired = normalize_tags(tags) - - current = set( - tag_name for (tag_name,) in ( - session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - session.flush() - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - -def replace_asset_info_metadata_projection( - session: Session, - *, - asset_info_id: str, - user_metadata: dict | None = None, -) -> None: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info.user_metadata = user_metadata or {} - info.updated_at = utcnow() - session.flush() - - session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) - session.flush() - - if not user_metadata: - return - - rows: list[AssetInfoMeta] = [] - for k, v in user_metadata.items(): - for r in project_kv(k, v): - rows.append( - AssetInfoMeta( - asset_info_id=asset_info_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - session.flush() - - -def ingest_fs_asset( - session: Session, - *, - asset_hash: str, - abs_path: str, - size_bytes: int, - mtime_ns: int, - mime_type: str | None = None, - info_name: str | None = None, - owner_id: str = "", - preview_id: str | None = None, - user_metadata: dict | None = None, - tags: Sequence[str] = (), - tag_origin: str = "manual", - require_existing_tags: bool = False, -) -> dict: - """ - Idempotently upsert: - - Asset by content hash (create if missing) - - AssetCacheState(file_path) pointing to asset_id - - Optionally AssetInfo + tag links and metadata projection - Returns flags and ids. - """ - locator = os.path.abspath(abs_path) - now = utcnow() - - if preview_id: - if not session.get(Asset, preview_id): - preview_id = None - - out: dict[str, Any] = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } - - # 1) Asset by hash - asset = ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - if not asset: - vals = { - "hash": asset_hash, - "size_bytes": int(size_bytes), - "mime_type": mime_type, - "created_at": now, - } - res = session.execute( - sqlite.insert(Asset) - .values(**vals) - .on_conflict_do_nothing(index_elements=[Asset.hash]) - ) - if int(res.rowcount or 0) > 0: - out["asset_created"] = True - asset = ( - session.execute( - select(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).scalars().first() - if not asset: - raise RuntimeError("Asset row not found after upsert.") - else: - changed = False - if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: - asset.size_bytes = int(size_bytes) - changed = True - if mime_type and asset.mime_type != mime_type: - asset.mime_type = mime_type - changed = True - if changed: - out["asset_updated"] = True - - # 2) AssetCacheState upsert by file_path (unique) - vals = { - "asset_id": asset.id, - "file_path": locator, - "mtime_ns": int(mtime_ns), - } - ins = ( - sqlite.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - - res = session.execute(ins) - if int(res.rowcount or 0) > 0: - out["state_created"] = True - else: - upd = ( - sa.update(AssetCacheState) - .where(AssetCacheState.file_path == locator) - .where( - sa.or_( - AssetCacheState.asset_id != asset.id, - AssetCacheState.mtime_ns.is_(None), - AssetCacheState.mtime_ns != int(mtime_ns), - ) - ) - .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) - ) - res2 = session.execute(upd) - if int(res2.rowcount or 0) > 0: - out["state_updated"] = True - - # 3) Optional AssetInfo + tags + metadata - if info_name: - try: - with session.begin_nested(): - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_id=asset.id, - preview_id=preview_id, - created_at=now, - updated_at=now, - last_access_time=now, - ) - session.add(info) - session.flush() - out["asset_info_id"] = info.id - except IntegrityError: - pass - - existing_info = ( - session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_id == asset.id, - AssetInfo.name == info_name, - (AssetInfo.owner_id == owner_id), - ) - .limit(1) - ) - ).unique().scalar_one_or_none() - if not existing_info: - raise RuntimeError("Failed to update or insert AssetInfo.") - - if preview_id and existing_info.preview_id != preview_id: - existing_info.preview_id = preview_id - - existing_info.updated_at = now - if existing_info.last_access_time < now: - existing_info.last_access_time = now - session.flush() - out["asset_info_id"] = existing_info.id - - norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] - if norm and out["asset_info_id"] is not None: - if not require_existing_tags: - ensure_tags_exist(session, norm, tag_type="user") - - existing_tag_names = set( - name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() - ) - missing = [t for t in norm if t not in existing_tag_names] - if missing and require_existing_tags: - raise ValueError(f"Unknown tags: {missing}") - - existing_links = set( - tag_name - for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) - ) - ).all() - ) - to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] - if to_add: - session.add_all( - [ - AssetInfoTag( - asset_info_id=out["asset_info_id"], - tag_name=t, - origin=tag_origin, - added_at=now, - ) - for t in to_add - ] - ) - session.flush() - - # metadata["filename"] hack - if out["asset_info_id"] is not None: - primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) - computed_filename = compute_relative_filename(primary_path) if primary_path else None - - current_meta = existing_info.user_metadata or {} - new_meta = dict(current_meta) - if user_metadata is not None: - for k, v in user_metadata.items(): - new_meta[k] = v - if computed_filename: - new_meta["filename"] = computed_filename - - if new_meta != current_meta: - replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=new_meta, - ) - - try: - remove_missing_tag_for_asset_id(session, asset_id=asset.id) - except Exception: - logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) - return out - - -def update_asset_info_full( - session: Session, - *, - asset_info_id: str, - name: str | None = None, - tags: Sequence[str] | None = None, - user_metadata: dict | None = None, - tag_origin: str = "manual", - asset_info_row: Any = None, -) -> AssetInfo: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - else: - info = asset_info_row - - touched = False - if name is not None and name != info.name: - info.name = name - touched = True - - computed_filename = None - try: - p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) - if p: - computed_filename = compute_relative_filename(p) - except Exception: - computed_filename = None - - if user_metadata is not None: - new_meta = dict(user_metadata) - if computed_filename: - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - else: - if computed_filename: - current_meta = info.user_metadata or {} - if current_meta.get("filename") != computed_filename: - new_meta = dict(current_meta) - new_meta["filename"] = computed_filename - replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - - if tags is not None: - set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - - if touched and user_metadata is None: - info.updated_at = utcnow() - session.flush() - - return info - - -def delete_asset_info_by_id( - session: Session, - *, - asset_info_id: str, - owner_id: str, -) -> bool: - stmt = sa.delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - return int((session.execute(stmt)).rowcount or 0) > 0 - - -def list_tags_with_usage( - session: Session, - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) - - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) - if not include_zero: - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (session.execute(q.limit(limit).offset(offset))).all() - total = (session.execute(total_q)).scalar_one() - - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - session.execute(ins) - - -def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: - return [ - tag_name for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -def add_tags_to_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") - - current = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - session.flush() - except IntegrityError: - nested.rollback() - - after = set(get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -def remove_tags_from_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - session.flush() - - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sa.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) - - -def set_asset_info_preview( - session: Session, - *, - asset_info_id: str, - preview_asset_id: str | None = None, -) -> None: - """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - if preview_asset_id is None: - info.preview_id = None - else: - # validate preview asset exists - if not session.get(Asset, preview_asset_id): - raise ValueError(f"Preview Asset {preview_asset_id} not found") - info.preview_id = preview_asset_id - - info.updated_at = utcnow() - session.flush() diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py new file mode 100644 index 000000000..7888d0645 --- /dev/null +++ b/app/assets/database/queries/__init__.py @@ -0,0 +1,121 @@ +from app.assets.database.queries.asset import ( + asset_exists_by_hash, + bulk_insert_assets, + get_asset_by_hash, + get_existing_asset_ids, + reassign_asset_references, + update_asset_hash_and_mime, + upsert_asset, +) +from app.assets.database.queries.asset_reference import ( + CacheStateRow, + UnenrichedReferenceRow, + bulk_insert_references_ignore_conflicts, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + convert_metadata_to_rows, + delete_assets_by_ids, + delete_orphaned_seed_asset, + delete_reference_by_id, + delete_references_by_ids, + fetch_reference_and_asset, + fetch_reference_asset_and_tags, + get_or_create_reference, + get_reference_by_file_path, + get_reference_by_id, + get_reference_with_owner_check, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_references_for_prefixes, + get_unenriched_references, + get_unreferenced_unhashed_asset_ids, + insert_reference, + list_references_by_asset_id, + list_references_page, + mark_references_missing_outside_prefixes, + reference_exists_for_asset_id, + restore_references_by_paths, + set_reference_metadata, + set_reference_preview, + soft_delete_reference_by_id, + update_reference_access_time, + update_reference_name, + update_reference_timestamps, + update_reference_updated_at, + upsert_reference, +) +from app.assets.database.queries.tags import ( + AddTagsResult, + RemoveTagsResult, + SetTagsResult, + add_missing_tag_for_asset_id, + add_tags_to_reference, + bulk_insert_tags_and_meta, + ensure_tags_exist, + get_reference_tags, + list_tags_with_usage, + remove_missing_tag_for_asset_id, + remove_tags_from_reference, + set_reference_tags, + validate_tags_exist, +) + +__all__ = [ + "AddTagsResult", + "CacheStateRow", + "RemoveTagsResult", + "SetTagsResult", + "UnenrichedReferenceRow", + "add_missing_tag_for_asset_id", + "add_tags_to_reference", + "asset_exists_by_hash", + "bulk_insert_assets", + "bulk_insert_references_ignore_conflicts", + "bulk_insert_tags_and_meta", + "bulk_update_enrichment_level", + "bulk_update_is_missing", + "bulk_update_needs_verify", + "convert_metadata_to_rows", + "delete_assets_by_ids", + "delete_orphaned_seed_asset", + "delete_reference_by_id", + "delete_references_by_ids", + "ensure_tags_exist", + "fetch_reference_and_asset", + "fetch_reference_asset_and_tags", + "get_asset_by_hash", + "get_existing_asset_ids", + "get_or_create_reference", + "get_reference_by_file_path", + "get_reference_by_id", + "get_reference_with_owner_check", + "get_reference_ids_by_ids", + "get_reference_tags", + "get_references_by_paths_and_asset_ids", + "get_references_for_prefixes", + "get_unenriched_references", + "get_unreferenced_unhashed_asset_ids", + "insert_reference", + "list_references_by_asset_id", + "list_references_page", + "list_tags_with_usage", + "mark_references_missing_outside_prefixes", + "reassign_asset_references", + "reference_exists_for_asset_id", + "remove_missing_tag_for_asset_id", + "remove_tags_from_reference", + "restore_references_by_paths", + "set_reference_metadata", + "set_reference_preview", + "soft_delete_reference_by_id", + "set_reference_tags", + "update_asset_hash_and_mime", + "update_reference_access_time", + "update_reference_name", + "update_reference_timestamps", + "update_reference_updated_at", + "upsert_asset", + "upsert_reference", + "validate_tags_exist", +] diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py new file mode 100644 index 000000000..a21f5b68f --- /dev/null +++ b/app/assets/database/queries/asset.py @@ -0,0 +1,140 @@ +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import sqlite +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks + + +def asset_exists_by_hash( + session: Session, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. + """ + row = ( + session.execute( + select(sa.literal(True)) + .select_from(Asset) + .where(Asset.hash == asset_hash) + .limit(1) + ) + ).first() + return row is not None + + +def get_asset_by_hash( + session: Session, + asset_hash: str, +) -> Asset | None: + return ( + (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))) + .scalars() + .first() + ) + + +def upsert_asset( + session: Session, + asset_hash: str, + size_bytes: int, + mime_type: str | None = None, +) -> tuple[Asset, bool, bool]: + """Upsert an Asset by hash. Returns (asset, created, updated).""" + vals = {"hash": asset_hash, "size_bytes": int(size_bytes)} + if mime_type: + vals["mime_type"] = mime_type + + ins = ( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + .scalars() + .first() + ) + if not asset: + raise RuntimeError("Asset row not found after upsert.") + + updated = False + if not created: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + updated = True + + return asset, created, updated + + +def bulk_insert_assets( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash.""" + if not rows: + return + ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash]) + for chunk in iter_chunks(rows, calculate_rows_per_statement(5)): + session.execute(ins, chunk) + + +def get_existing_asset_ids( + session: Session, + asset_ids: list[str], +) -> set[str]: + """Return the subset of asset_ids that exist in the database.""" + if not asset_ids: + return set() + found: set[str] = set() + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + rows = session.execute( + select(Asset.id).where(Asset.id.in_(chunk)) + ).fetchall() + found.update(row[0] for row in rows) + return found + + +def update_asset_hash_and_mime( + session: Session, + asset_id: str, + asset_hash: str | None = None, + mime_type: str | None = None, +) -> bool: + """Update asset hash and/or mime_type. Returns True if asset was found.""" + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset_hash is not None: + asset.hash = asset_hash + if mime_type is not None: + asset.mime_type = mime_type + return True + + +def reassign_asset_references( + session: Session, + from_asset_id: str, + to_asset_id: str, + reference_id: str, +) -> None: + """Reassign a reference from one asset to another. + + Used when merging a stub asset into an existing asset with the same hash. + """ + ref = session.get(AssetReference, reference_id) + if ref and ref.asset_id == from_asset_id: + ref.asset_id = to_asset_id + + session.flush() diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py new file mode 100644 index 000000000..6524791cc --- /dev/null +++ b/app/assets/database/queries/asset_reference.py @@ -0,0 +1,1033 @@ +"""Query functions for the unified AssetReference table. + +This module replaces the separate asset_info.py and cache_state.py query modules, +providing a unified interface for the merged asset_references table. +""" + +from collections import defaultdict +from datetime import datetime +from decimal import Decimal +from typing import NamedTuple, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, exists, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, noload + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + build_prefix_like_conditions, + build_visible_owner_clause, + calculate_rows_per_statement, + iter_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +def _check_is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _scalar_to_row(key: str, ordinal: int, value) -> dict: + """Convert a scalar value to a typed projection row.""" + if value is None: + return { + "key": key, + "ordinal": ordinal, + "val_str": None, + "val_num": None, + "val_bool": None, + "val_json": None, + } + if isinstance(value, bool): + return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return {"key": key, "ordinal": ordinal, "val_num": num} + if isinstance(value, str): + return {"key": key, "ordinal": ordinal, "val_str": value} + return {"key": key, "ordinal": ordinal, "val_json": value} + + +def convert_metadata_to_rows(key: str, value) -> list[dict]: + """Turn a metadata key/value into typed projection rows.""" + if value is None: + return [_scalar_to_row(key, 0, None)] + + if _check_is_scalar(value): + return [_scalar_to_row(key, 0, value)] + + if isinstance(value, list): + if all(_check_is_scalar(x) for x in value): + return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] + + return [{"key": key, "ordinal": 0, "val_json": value}] + + +def _apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def _apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetReferenceMeta.val_json.is_(None), + AssetReferenceMeta.val_str.is_(None), + AssetReferenceMeta.val_num.is_(None), + AssetReferenceMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def get_reference_by_id( + session: Session, + reference_id: str, +) -> AssetReference | None: + return session.get(AssetReference, reference_id) + + +def get_reference_with_owner_check( + session: Session, + reference_id: str, + owner_id: str, +) -> AssetReference: + """Fetch a reference and verify ownership. + + Raises: + ValueError: if reference not found or soft-deleted + PermissionError: if owner_id doesn't match + """ + ref = get_reference_by_id(session, reference_id=reference_id) + if not ref or ref.deleted_at is not None: + raise ValueError(f"AssetReference {reference_id} not found") + if ref.owner_id and ref.owner_id != owner_id: + raise PermissionError("not owner") + return ref + + +def get_reference_by_file_path( + session: Session, + file_path: str, +) -> AssetReference | None: + """Get a reference by its file path.""" + return ( + session.execute( + select(AssetReference).where(AssetReference.file_path == file_path).limit(1) + ) + .scalars() + .first() + ) + + +def reference_exists_for_asset_id( + session: Session, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + +def insert_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> AssetReference | None: + """Insert a new AssetReference. Returns None if unique constraint violated.""" + now = get_utc_now() + try: + with session.begin_nested(): + ref = AssetReference( + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + except IntegrityError: + return None + + +def get_or_create_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> tuple[AssetReference, bool]: + """Get existing or create new AssetReference. + + For filesystem references (file_path is set), uniqueness is by file_path. + For API references (file_path is None), we look for matching + asset_id + owner_id + name. + + Returns (reference, created). + """ + ref = insert_reference( + session, + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + ) + if ref: + return ref, True + + # Find existing - priority to file_path match, then name match + if file_path: + existing = get_reference_by_file_path(session, file_path) + else: + existing = ( + session.execute( + select(AssetReference) + .where( + AssetReference.asset_id == asset_id, + AssetReference.name == name, + AssetReference.owner_id == owner_id, + AssetReference.file_path.is_(None), + ) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + if not existing: + raise RuntimeError("Failed to find AssetReference after insert conflict.") + return existing, False + + +def update_reference_timestamps( + session: Session, + reference: AssetReference, + preview_id: str | None = None, +) -> None: + """Update timestamps and optionally preview_id on existing AssetReference.""" + now = get_utc_now() + if preview_id and reference.preview_id != preview_id: + reference.preview_id = preview_id + reference.updated_at = now + + +def list_references_page( + session: Session, + owner_id: str = "", + limit: int = 100, + offset: int = 0, + name_contains: str | None = None, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + metadata_filter: dict | None = None, + sort: str | None = None, + order: str | None = None, +) -> tuple[list[AssetReference], dict[str, list[str]], int]: + """List references with pagination, filtering, and sorting. + + Returns (references, tag_map, total_count). + """ + base = ( + select(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .options(noload(AssetReference.tags)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + base = _apply_tag_filters(base, include_tags, exclude_tags) + base = _apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetReference.name, + "created_at": AssetReference.created_at, + "updated_at": AssetReference.updated_at, + "last_access_time": AssetReference.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetReference.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + count_stmt = count_stmt.where( + AssetReference.name.ilike(f"%{escaped}%", escape=esc) + ) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + + total = int(session.execute(count_stmt).scalar_one() or 0) + refs = session.execute(base).unique().scalars().all() + + id_list: list[str] = [r.id for r in refs] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetReferenceTag.asset_reference_id, Tag.name) + .join(Tag, Tag.name == AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id.in_(id_list)) + .order_by(AssetReferenceTag.added_at) + ) + for ref_id, tag_name in rows.all(): + tag_map[ref_id].append(tag_name) + + return list(refs), tag_map, total + + +def fetch_reference_asset_and_tags( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset, list[str]] | None: + stmt = ( + select(AssetReference, Asset, Tag.name) + .join(Asset, Asset.id == AssetReference.asset_id) + .join( + AssetReferenceTag, + AssetReferenceTag.asset_reference_id == AssetReference.id, + isouter=True, + ) + .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .options(noload(AssetReference.tags)) + .order_by(Tag.name.asc()) + ) + + rows = session.execute(stmt).all() + if not rows: + return None + + first_ref, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _ref, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_ref, first_asset, tags + + +def fetch_reference_and_asset( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset] | None: + stmt = ( + select(AssetReference, Asset) + .join(Asset, Asset.id == AssetReference.asset_id) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetReference.tags)) + ) + pair = session.execute(stmt).first() + if not pair: + return None + return pair[0], pair[1] + + +def update_reference_access_time( + session: Session, + reference_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or get_utc_now() + stmt = sa.update(AssetReference).where(AssetReference.id == reference_id) + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetReference.last_access_time.is_(None), + AssetReference.last_access_time < ts, + ) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def update_reference_name( + session: Session, + reference_id: str, + name: str, +) -> None: + """Update the name of an AssetReference.""" + now = get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(name=name, updated_at=now) + ) + + +def update_reference_updated_at( + session: Session, + reference_id: str, + ts: datetime | None = None, +) -> None: + """Update the updated_at timestamp of an AssetReference.""" + ts = ts or get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(updated_at=ts) + ) + + +def set_reference_metadata( + session: Session, + reference_id: str, + user_metadata: dict | None = None, +) -> None: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.user_metadata = user_metadata or {} + ref.updated_at = get_utc_now() + session.flush() + + session.execute( + delete(AssetReferenceMeta).where( + AssetReferenceMeta.asset_reference_id == reference_id + ) + ) + session.flush() + + if not user_metadata: + return + + rows: list[AssetReferenceMeta] = [] + for k, v in user_metadata.items(): + for r in convert_metadata_to_rows(k, v): + rows.append( + AssetReferenceMeta( + asset_reference_id=reference_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetReference).where( + AssetReference.id == reference_id, + build_visible_owner_clause(owner_id), + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def soft_delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + """Mark a reference as soft-deleted by setting deleted_at timestamp. + + Returns True if the reference was found and marked deleted. + """ + now = get_utc_now() + stmt = ( + sa.update(AssetReference) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .values(deleted_at=now) + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def set_reference_preview( + session: Session, + reference_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + if preview_asset_id is None: + ref.preview_id = None + else: + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + ref.preview_id = preview_asset_id + + ref.updated_at = get_utc_now() + session.flush() + + +class CacheStateRow(NamedTuple): + """Row from reference query with cache state data.""" + + reference_id: str + file_path: str + mtime_ns: int | None + needs_verify: bool + asset_id: str + asset_hash: str | None + size_bytes: int | None + + +def list_references_by_asset_id( + session: Session, + asset_id: str, +) -> Sequence[AssetReference]: + return ( + session.execute( + select(AssetReference) + .where(AssetReference.asset_id == asset_id) + .order_by(AssetReference.id.asc()) + ) + .scalars() + .all() + ) + + +def upsert_reference( + session: Session, + asset_id: str, + file_path: str, + name: str, + mtime_ns: int, + owner_id: str = "", +) -> tuple[bool, bool]: + """Upsert a reference by file_path. Returns (created, updated). + + Also restores references that were previously marked as missing. + """ + now = get_utc_now() + vals = { + "asset_id": asset_id, + "file_path": file_path, + "name": name, + "owner_id": owner_id, + "mtime_ns": int(mtime_ns), + "is_missing": False, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ins = ( + sqlite.insert(AssetReference) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetReference.file_path]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + if created: + return True, False + + upd = ( + sa.update(AssetReference) + .where(AssetReference.file_path == file_path) + .where( + sa.or_( + AssetReference.asset_id != asset_id, + AssetReference.mtime_ns.is_(None), + AssetReference.mtime_ns != int(mtime_ns), + AssetReference.is_missing == True, # noqa: E712 + AssetReference.deleted_at.isnot(None), + ) + ) + .values( + asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, + deleted_at=None, updated_at=now, + ) + ) + res2 = session.execute(upd) + updated = int(res2.rowcount or 0) > 0 + return False, updated + + +def mark_references_missing_outside_prefixes( + session: Session, + valid_prefixes: list[str], +) -> int: + """Mark references as missing when file_path doesn't match any valid prefix. + + Returns number of references marked as missing. + """ + if not valid_prefixes: + return 0 + + conds = build_prefix_like_conditions(valid_prefixes) + matches_valid_prefix = sa.or_(*conds) + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(~matches_valid_prefix) + .where(AssetReference.is_missing == False) # noqa: E712 + .values(is_missing=True) + ) + return result.rowcount + + +def restore_references_by_paths(session: Session, file_paths: list[str]) -> int: + """Restore references that were previously marked as missing. + + Returns number of references restored. + """ + if not file_paths: + return 0 + + total = 0 + for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.in_(chunk)) + .where(AssetReference.is_missing == True) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=False) + ) + total += result.rowcount + return total + + +def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: + """Get IDs of unhashed assets (hash=None) with no active references. + + An asset is considered unreferenced if it has no references, + or all its references are marked as missing. + + Returns list of asset IDs that are unreferenced. + """ + active_ref_exists = ( + sa.select(sa.literal(1)) + .where(AssetReference.asset_id == Asset.id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .correlate(Asset) + .exists() + ) + unreferenced_subq = sa.select(Asset.id).where( + Asset.hash.is_(None), ~active_ref_exists + ) + return [row[0] for row in session.execute(unreferenced_subq).all()] + + +def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int: + """Delete assets and their references by ID. + + Returns number of assets deleted. + """ + if not asset_ids: + return 0 + total = 0 + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk)) + ) + result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk))) + total += result.rowcount + return total + + +def get_references_for_prefixes( + session: Session, + prefixes: list[str], + *, + include_missing: bool = False, +) -> list[CacheStateRow]: + """Get all references with file paths matching any of the given prefixes. + + Args: + session: Database session + prefixes: List of absolute directory prefixes to match + include_missing: If False (default), exclude references marked as missing + + Returns: + List of cache state rows with joined asset data + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.file_path, + AssetReference.mtime_ns, + AssetReference.needs_verify, + AssetReference.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + ) + + if not include_missing: + query = query.where(AssetReference.is_missing == False) # noqa: E712 + + rows = session.execute( + query.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc()) + ).all() + + return [ + CacheStateRow( + reference_id=row[0], + file_path=row[1], + mtime_ns=row[2], + needs_verify=row[3], + asset_id=row[4], + asset_hash=row[5], + size_bytes=int(row[6]) if row[6] is not None else None, + ) + for row in rows + ] + + +def bulk_update_needs_verify( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set needs_verify flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(needs_verify=value) + ) + total += result.rowcount + return total + + +def bulk_update_is_missing( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set is_missing flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(is_missing=value) + ) + total += result.rowcount + return total + + +def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: + """Delete references by their IDs. + + Returns: Number of rows deleted + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.delete(AssetReference).where(AssetReference.id.in_(chunk)) + ) + total += result.rowcount + return total + + +def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool: + """Delete a seed asset (hash is None) and its references. + + Returns: True if asset was deleted, False if not found or has a hash + """ + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset.hash is not None: + return False + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id == asset_id) + ) + session.delete(asset) + return True + + +class UnenrichedReferenceRow(NamedTuple): + """Row for references needing enrichment.""" + + reference_id: str + asset_id: str + file_path: str + enrichment_level: int + + +def get_unenriched_references( + session: Session, + prefixes: list[str], + max_level: int = 0, + limit: int = 1000, +) -> list[UnenrichedReferenceRow]: + """Get references that need enrichment (enrichment_level <= max_level). + + Args: + session: Database session + prefixes: List of absolute directory prefixes to scan + max_level: Maximum enrichment level to include (0=stubs, 1=metadata done) + limit: Maximum number of rows to return + + Returns: + List of unenriched reference rows with file paths + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.asset_id, + AssetReference.file_path, + AssetReference.enrichment_level, + ) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.enrichment_level <= max_level) + .order_by(AssetReference.id.asc()) + .limit(limit) + ) + + rows = session.execute(query).all() + return [ + UnenrichedReferenceRow( + reference_id=row[0], + asset_id=row[1], + file_path=row[2], + enrichment_level=row[3], + ) + for row in rows + ] + + +def bulk_update_enrichment_level( + session: Session, + reference_ids: list[str], + level: int, +) -> int: + """Update enrichment level for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(reference_ids)) + .values(enrichment_level=level) + ) + return result.rowcount + + +def bulk_insert_references_ignore_conflicts( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert reference rows with ON CONFLICT DO NOTHING on file_path. + + Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc. + The is_missing field is automatically set to False for new inserts. + """ + if not rows: + return + enriched_rows = [{**row, "is_missing": False} for row in rows] + ins = sqlite.insert(AssetReference).on_conflict_do_nothing( + index_elements=[AssetReference.file_path] + ) + for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)): + session.execute(ins, chunk) + + +def get_references_by_paths_and_asset_ids( + session: Session, + path_to_asset: dict[str, str], +) -> set[str]: + """Query references to find paths where our asset_id won the insert. + + Args: + path_to_asset: Mapping of file_path -> asset_id we tried to insert + + Returns: + Set of file_paths where our asset_id is present + """ + if not path_to_asset: + return set() + + pairs = list(path_to_asset.items()) + winners: set[str] = set() + + # Each pair uses 2 bind params, so chunk at MAX_BIND_PARAMS // 2 + for chunk in iter_chunks(pairs, MAX_BIND_PARAMS // 2): + pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( + chunk + ) + result = session.execute( + select(AssetReference.file_path).where(pairwise) + ) + winners.update(result.scalars().all()) + + return winners + + +def get_reference_ids_by_ids( + session: Session, + reference_ids: list[str], +) -> set[str]: + """Query to find which reference IDs exist in the database.""" + if not reference_ids: + return set() + + found: set[str] = set() + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + select(AssetReference.id).where(AssetReference.id.in_(chunk)) + ) + found.update(result.scalars().all()) + return found diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py new file mode 100644 index 000000000..194c39a1e --- /dev/null +++ b/app/assets/database/queries/common.py @@ -0,0 +1,54 @@ +"""Shared utilities for database query modules.""" + +import os +from typing import Iterable + +import sqlalchemy as sa + +from app.assets.database.models import AssetReference +from app.assets.helpers import escape_sql_like_string + +MAX_BIND_PARAMS = 800 + + +def calculate_rows_per_statement(cols: int) -> int: + """Calculate how many rows can fit in one statement given column count.""" + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def iter_chunks(seq, n: int): + """Yield successive n-sized chunks from seq.""" + for i in range(0, len(seq), n): + yield seq[i : i + n] + + +def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: + """Yield chunks of rows sized to fit within bind param limits.""" + if not rows: + return + yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row)) + + +def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. + + Owner-less rows are visible to everyone. + """ + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetReference.owner_id == "" + return AssetReference.owner_id.in_(["", owner_id]) + + +def build_prefix_like_conditions( + prefixes: list[str], +) -> list[sa.sql.ColumnElement]: + """Build LIKE conditions for matching file paths under directory prefixes.""" + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_sql_like_string(base) + conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + return conds diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py new file mode 100644 index 000000000..8b25fee67 --- /dev/null +++ b/app/assets/database/queries/tags.py @@ -0,0 +1,356 @@ +from dataclasses import dataclass +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + build_visible_owner_clause, + iter_row_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +@dataclass(frozen=True) +class AddTagsResult: + added: list[str] + already_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class RemoveTagsResult: + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class SetTagsResult: + added: list[str] + removed: list[str] + total: list[str] + + +def validate_tags_exist(session: Session, tags: list[str]) -> None: + """Raise ValueError if any of the given tag names do not exist.""" + existing_tag_names = set( + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + ) + missing = [t for t in tags if t not in existing_tag_names] + if missing: + raise ValueError(f"Unknown tags: {missing}") + + +def ensure_tags_exist( + session: Session, names: Iterable[str], tag_type: str = "user" +) -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_reference_tags(session: Session, reference_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + session.execute( + select(AssetReferenceTag.tag_name).where( + AssetReferenceTag.asset_reference_id == reference_id + ) + ) + ).all() + ] + + +def set_reference_tags( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> SetTagsResult: + desired = normalize_tags(tags) + + current = set(get_reference_tags(session, reference_id)) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + return SetTagsResult(added=to_add, removed=to_remove, total=desired) + + +def add_tags_to_reference( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + reference_row: AssetReference | None = None, +) -> AddTagsResult: + if not reference_row: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return AddTagsResult(added=[], already_present=[], total_tags=total) + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = set(get_reference_tags(session, reference_id)) + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_reference_tags(session, reference_id=reference_id)) + return AddTagsResult( + added=sorted(((after - current) & want)), + already_present=sorted(want & current), + total_tags=sorted(after), + ) + + +def remove_tags_from_reference( + session: Session, + reference_id: str, + tags: Sequence[str], +) -> RemoveTagsResult: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=[], not_present=[], total_tags=total) + + existing = set(get_reference_tags(session, reference_id)) + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total) + + +def add_missing_tag_for_asset_id( + session: Session, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetReference.id.label("asset_reference_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(get_utc_now()).label("added_at"), + ) + .where(AssetReference.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == "missing") + ) + ) + ) + ) + session.execute( + sqlite.insert(AssetReferenceTag) + .from_select( + ["asset_reference_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + ) + + +def remove_missing_tag_for_asset_id( + session: Session, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id.in_( + sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id) + ), + AssetReferenceTag.tag_name == "missing", + ) + ) + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetReferenceTag.tag_name.label("tag_name"), + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .select_from(AssetReferenceTag) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + visible_tags_sq = ( + select(AssetReferenceTag.tag_name) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + ) + total_q = total_q.where(Tag.name.in_(visible_tags_sq)) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +def bulk_insert_tags_and_meta( + session: Session, + tag_rows: list[dict], + meta_rows: list[dict], +) -> None: + """Batch insert into asset_reference_tags and asset_reference_meta. + + Uses ON CONFLICT DO NOTHING. + + Args: + session: Database session + tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at + meta_rows: Dicts with: asset_reference_id, key, ordinal, val_* + """ + if tag_rows: + ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + for chunk in iter_row_chunks(tag_rows, cols_per_row=4): + session.execute(ins_tags, chunk) + + if meta_rows: + ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing( + index_elements=[ + AssetReferenceMeta.asset_reference_id, + AssetReferenceMeta.key, + AssetReferenceMeta.ordinal, + ] + ) + for chunk in iter_row_chunks(meta_rows, cols_per_row=7): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py deleted file mode 100644 index 3ab6497c2..000000000 --- a/app/assets/database/tags.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Iterable - -import sqlalchemy -from sqlalchemy.orm import Session -from sqlalchemy.dialects import sqlite - -from app.assets.helpers import normalize_tags, utcnow -from app.assets.database.models import Tag, AssetInfoTag, AssetInfo - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - return session.execute(ins) - -def add_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, - origin: str = "automatic", -) -> None: - select_rows = ( - sqlalchemy.select( - AssetInfo.id.label("asset_info_id"), - sqlalchemy.literal("missing").label("tag_name"), - sqlalchemy.literal(origin).label("origin"), - sqlalchemy.literal(utcnow()).label("added_at"), - ) - .where(AssetInfo.asset_id == asset_id) - .where( - sqlalchemy.not_( - sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) - ) - ) - ) - session.execute( - sqlite.insert(AssetInfoTag) - .from_select( - ["asset_info_id", "tag_name", "origin", "added_at"], - select_rows, - ) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sqlalchemy.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py deleted file mode 100644 index 4b72084b9..000000000 --- a/app/assets/hashing.py +++ /dev/null @@ -1,75 +0,0 @@ -from blake3 import blake3 -from typing import IO -import os -import asyncio - - -DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB - -# NOTE: this allows hashing different representations of a file-like object -def blake3_hash( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """ - Returns a BLAKE3 hex digest for ``fp``, which may be: - - a filename (str/bytes) or PathLike - - an open binary file object - If ``fp`` is a file object, it must be opened in **binary** mode and support - ``read``, ``seek``, and ``tell``. The function will seek to the start before - reading and will attempt to restore the original position afterward. - """ - # duck typing to check if input is a file-like object - if hasattr(fp, "read"): - return _hash_file_obj(fp, chunk_size) - - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - -async def blake3_hash_async( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - """Async wrapper for ``blake3_hash_sync``. - Uses a worker thread so the event loop remains responsive. - """ - # If it is a path, open inside the worker thread to keep I/O off the loop. - if hasattr(fp, "read"): - return await asyncio.to_thread(blake3_hash, fp, chunk_size) - - def _worker() -> str: - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - return await asyncio.to_thread(_worker) - - -def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: - """ - Hash an already-open binary file object by streaming in chunks. - - Seeks to the beginning before reading (if supported). - - Restores the original position afterward (if tell/seek are supported). - """ - if chunk_size <= 0: - chunk_size = DEFAULT_CHUNK - - # in case file object is already open and not at the beginning, track so can be restored after hashing - orig_pos = file_obj.tell() - - try: - # seek to the beginning before reading - if orig_pos != 0: - file_obj.seek(0) - - h = blake3() - while True: - chunk = file_obj.read(chunk_size) - if not chunk: - break - h.update(chunk) - return h.hexdigest() - finally: - # restore original position in file object, if needed - if orig_pos != 0: - file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 5030b123a..3798f3933 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -1,226 +1,42 @@ -import contextlib import os -from decimal import Decimal -from aiohttp import web from datetime import datetime, timezone -from pathlib import Path -from typing import Literal, Any - -import folder_paths +from typing import Sequence -RootType = Literal["models", "input", "output"] -ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - -def get_query_dict(request: web.Request) -> dict[str, Any]: +def select_best_live_path(states: Sequence) -> str: """ - Gets a dictionary of query parameters from the request. - - 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. """ - query_dict = { - key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) - for key in request.query.keys() - } - return query_dict + alive = [ + s + for s in states + if getattr(s, "file_path", None) and os.path.isfile(s.file_path) + ] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path -def list_tree(base_dir: str) -> list[str]: - out: list[str] = [] - base_abs = os.path.abspath(base_dir) - if not os.path.isdir(base_abs): - return out - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - out.append(os.path.abspath(os.path.join(dirpath, name))) - return out -def prefixes_for_root(root: RootType) -> list[str]: - if root == "models": - bases: list[str] = [] - for _bucket, paths in get_comfy_models_folders(): - bases.extend(paths) - return [os.path.abspath(p) for p in bases] - if root == "input": - return [os.path.abspath(folder_paths.get_input_directory())] - if root == "output": - return [os.path.abspath(folder_paths.get_output_directory())] - return [] +def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char in a LIKE prefix. -def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: - """Escapes %, _ and the escape char itself in a LIKE prefix. - Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + Returns (escaped_prefix, escape_char). """ s = s.replace(escape, escape + escape) # escape the escape char first s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards return s, escape -def fast_asset_file_check( - *, - mtime_db: int | None, - size_db: int | None, - stat_result: os.stat_result, -) -> bool: - if mtime_db is None: - return False - actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) - if int(mtime_db) != int(actual_mtime_ns): - return False - sz = int(size_db or 0) - if sz > 0: - return int(stat_result.st_size) == sz - return True -def utcnow() -> datetime: +def get_utc_now() -> datetime: """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" return datetime.now(timezone.utc).replace(tzinfo=None) -def get_comfy_models_folders() -> list[tuple[str, list[str]]]: - """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. - - We trust `folder_paths.folder_names_and_paths` and include a category if - *any* of its base paths lies under the Comfy `models_dir`. - """ - targets: list[tuple[str, list[str]]] = [] - models_root = os.path.abspath(folder_paths.models_dir) - for name, values in folder_paths.folder_names_and_paths.items(): - paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI - if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): - targets.append((name, paths)) - return targets - -def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: - """Validates and maps tags -> (base_dir, subdirs_for_fs)""" - root = tags[0] - if root == "models": - if len(tags) < 2: - raise ValueError("at least two tags required for model asset") - try: - bases = folder_paths.folder_names_and_paths[tags[1]][0] - except KeyError: - raise ValueError(f"unknown model category '{tags[1]}'") - if not bases: - raise ValueError(f"no base path configured for category '{tags[1]}'") - base_dir = os.path.abspath(bases[0]) - raw_subdirs = tags[2:] - else: - base_dir = os.path.abspath( - folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() - ) - raw_subdirs = tags[1:] - for i in raw_subdirs: - if i in (".", ".."): - raise ValueError("invalid path component in tags") - - return base_dir, raw_subdirs if raw_subdirs else [] - -def ensure_within_base(candidate: str, base: str) -> None: - cand_abs = os.path.abspath(candidate) - base_abs = os.path.abspath(base) - try: - if os.path.commonpath([cand_abs, base_abs]) != base_abs: - raise ValueError("destination escapes base directory") - except Exception: - raise ValueError("invalid destination path") - -def compute_relative_filename(file_path: str) -> str | None: - """ - Return the model's path relative to the last well-known folder (the model category), - using forward slashes, eg: - /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" - /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" - - For non-model paths, returns None. - NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. - """ - try: - root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) - except ValueError: - return None - - p = Path(rel_path) - parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] - if not parts: - return None - - if root_category == "models": - # parts[0] is the category ("checkpoints", "vae", etc) – drop it - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) - return "/".join(parts) # input/output: keep all parts - -def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: - """Given an absolute or relative file path, determine which root category the path belongs to: - - 'input' if the file resides under `folder_paths.get_input_directory()` - - 'output' if the file resides under `folder_paths.get_output_directory()` - - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` - - Returns: - (root_category, relative_path_inside_that_root) - For 'models', the relative path is prefixed with the category name: - e.g. ('models', 'vae/test/sub/ae.safetensors') - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - fp_abs = os.path.abspath(file_path) - - def _is_within(child: str, parent: str) -> bool: - try: - return os.path.commonpath([child, parent]) == parent - except Exception: - return False - - def _rel(child: str, parent: str) -> str: - return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) - - # 1) input - input_base = os.path.abspath(folder_paths.get_input_directory()) - if _is_within(fp_abs, input_base): - return "input", _rel(fp_abs, input_base) - - # 2) output - output_base = os.path.abspath(folder_paths.get_output_directory()) - if _is_within(fp_abs, output_base): - return "output", _rel(fp_abs, output_base) - - # 3) models (check deepest matching base to avoid ambiguity) - best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) - for bucket, bases in get_comfy_models_folders(): - for b in bases: - base_abs = os.path.abspath(b) - if not _is_within(fp_abs, base_abs): - continue - cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) - if best is None or cand[0] > best[0]: - best = cand - - if best is not None: - _, bucket, rel_inside = best - combined = os.path.join(bucket, rel_inside) - return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) - - raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") - -def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: - """Return a tuple (name, tags) derived from a filesystem path. - - Semantics: - - Root category is determined by `get_relative_to_root_category_path_of_asset`. - - The returned `name` is the base filename with extension from the relative path. - - The returned `tags` are: - [root_category] + parent folders of the relative path (in order) - For 'models', this means: - file '/.../ModelsDir/vae/test_tag/ae.safetensors' - -> root_category='models', some_path='vae/test_tag/ae.safetensors' - -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] - - Raises: - ValueError: if the path does not belong to input, output, or configured model bases. - """ - root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) - p = Path(some_path) - parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] - return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) def normalize_tags(tags: list[str] | None) -> list[str]: """ @@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]: - Stripping whitespace and converting to lowercase. - Removing duplicates. """ - return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) -def collect_models_files() -> list[str]: - out: list[str] = [] - for folder_name, bases in get_comfy_models_folders(): - rel_files = folder_paths.get_filename_list(folder_name) or [] - for rel_path in rel_files: - abs_path = folder_paths.get_full_path(folder_name, rel_path) - if not abs_path: - continue - abs_path = os.path.abspath(abs_path) - allowed = False - for b in bases: - base_abs = os.path.abspath(b) - with contextlib.suppress(Exception): - if os.path.commonpath([abs_path, base_abs]) == base_abs: - allowed = True - break - if allowed: - out.append(abs_path) - return out -def is_scalar(v): - if v is None: - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False +def validate_blake3_hash(s: str) -> str: + """Validate and normalize a blake3 hash string. -def project_kv(key: str, value): + Returns canonical 'blake3:' or raises ValueError. """ - Turn a metadata key/value into typed projection rows. - Returns list[dict] with keys: - key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) - """ - rows: list[dict] = [] - - def _null_row(ordinal: int) -> dict: - return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None - } - - if value is None: - rows.append(_null_row(0)) - return rows - - if is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append(_null_row(i)) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num = x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows + s = s.strip().lower() + if not s or ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if ( + algo != "blake3" + or len(digest) != 64 + or any(c for c in digest if c not in "0123456789abcdef") + ): + raise ValueError("hash must be 'blake3:'") + return f"{algo}:{digest}" diff --git a/app/assets/manager.py b/app/assets/manager.py deleted file mode 100644 index a68c8c8ae..000000000 --- a/app/assets/manager.py +++ /dev/null @@ -1,516 +0,0 @@ -import os -import mimetypes -import contextlib -from typing import Sequence - -from app.database.db import create_session -from app.assets.api import schemas_out, schemas_in -from app.assets.database.queries import ( - asset_exists_by_hash, - asset_info_exists_for_asset_id, - get_asset_by_hash, - get_asset_info_by_id, - fetch_asset_info_asset_and_tags, - fetch_asset_info_and_asset, - create_asset_info_for_existing_asset, - touch_asset_info_by_id, - update_asset_info_full, - delete_asset_info_by_id, - list_cache_states_by_asset_id, - list_asset_infos_page, - list_tags_with_usage, - get_asset_tags, - add_tags_to_asset_info, - remove_tags_from_asset_info, - pick_best_live_path, - ingest_fs_asset, - set_asset_info_preview, -) -from app.assets.helpers import resolve_destination_from_tags, ensure_within_base -from app.assets.database.models import Asset - - -def _safe_sort_field(requested: str | None) -> str: - if not requested: - return "created_at" - v = requested.lower() - if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: - return v - return "created_at" - - -def _get_size_mtime_ns(path: str) -> tuple[int, int]: - st = os.stat(path, follow_symlinks=True) - return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - - -def _safe_filename(name: str | None, fallback: str) -> str: - n = os.path.basename((name or "").strip() or fallback) - if n: - return n - return fallback - - -def asset_exists(*, asset_hash: str) -> bool: - """ - Check if an asset with a given hash exists in database. - """ - with create_session() as session: - return asset_exists_by_hash(session, asset_hash=asset_hash) - - -def list_assets( - *, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, - name_contains: str | None = None, - metadata_filter: dict | None = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", - owner_id: str = "", -) -> schemas_out.AssetsList: - sort = _safe_sort_field(sort) - order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() - - with create_session() as session: - infos, tag_map, total = list_asset_infos_page( - session, - owner_id=owner_id, - include_tags=include_tags, - exclude_tags=exclude_tags, - name_contains=name_contains, - metadata_filter=metadata_filter, - limit=limit, - offset=offset, - sort=sort, - order=order, - ) - - summaries: list[schemas_out.AssetSummary] = [] - for info in infos: - asset = info.asset - tags = tag_map.get(info.id, []) - summaries.append( - schemas_out.AssetSummary( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - created_at=info.created_at, - updated_at=info.updated_at, - last_access_time=info.last_access_time, - ) - ) - - return schemas_out.AssetsList( - assets=summaries, - total=total, - has_more=(offset + len(summaries)) < total, - ) - - -def get_asset( - *, - asset_info_id: str, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise ValueError(f"AssetInfo {asset_info_id} not found") - info, asset, tag_names = res - preview_id = info.preview_id - - return schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - - -def resolve_asset_content_for_download( - *, - asset_info_id: str, - owner_id: str = "", -) -> tuple[str, str, str]: - with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not pair: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info, asset = pair - states = list_cache_states_by_asset_id(session, asset_id=asset.id) - abs_path = pick_best_live_path(states) - if not abs_path: - raise FileNotFoundError - - touch_asset_info_by_id(session, asset_info_id=asset_info_id) - session.commit() - - ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" - download_name = info.name or os.path.basename(abs_path) - return abs_path, ctype, download_name - - -def upload_asset_from_temp_path( - spec: schemas_in.UploadAssetSpec, - *, - temp_path: str, - client_filename: str | None = None, - owner_id: str = "", - expected_asset_hash: str | None = None, -) -> schemas_out.AssetCreated: - """ - Create new asset or update existing asset from a temporary file path. - """ - try: - # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment - import app.assets.hashing as hashing - digest = hashing.blake3_hash(temp_path) - except Exception as e: - raise RuntimeError(f"failed to hash uploaded file: {e}") - asset_hash = "blake3:" + digest - - if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): - raise ValueError("HASH_MISMATCH") - - with create_session() as session: - existing = get_asset_by_hash(session, asset_hash=asset_hash) - if existing is not None: - with contextlib.suppress(Exception): - if temp_path and os.path.exists(temp_path): - os.remove(temp_path) - - display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) - info = create_asset_info_for_existing_asset( - session, - asset_hash=asset_hash, - name=display_name, - user_metadata=spec.user_metadata or {}, - tags=spec.tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - session.commit() - - return schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=existing.hash, - size=int(existing.size_bytes) if existing.size_bytes is not None else None, - mime_type=existing.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - - base_dir, subdirs = resolve_destination_from_tags(spec.tags) - dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir - os.makedirs(dest_dir, exist_ok=True) - - src_for_ext = (client_filename or spec.name or "").strip() - _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" - ext = _ext if 0 < len(_ext) <= 16 else "" - hashed_basename = f"{digest}{ext}" - dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) - ensure_within_base(dest_abs, base_dir) - - content_type = ( - mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] - or mimetypes.guess_type(hashed_basename, strict=False)[0] - or "application/octet-stream" - ) - - try: - os.replace(temp_path, dest_abs) - except Exception as e: - raise RuntimeError(f"failed to move uploaded file into place: {e}") - - try: - size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) - except OSError as e: - raise RuntimeError(f"failed to stat destination file: {e}") - - with create_session() as session: - result = ingest_fs_asset( - session, - asset_hash=asset_hash, - abs_path=dest_abs, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - mime_type=content_type, - info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), - owner_id=owner_id, - preview_id=None, - user_metadata=spec.user_metadata or {}, - tags=spec.tags, - tag_origin="manual", - require_existing_tags=False, - ) - info_id = result["asset_info_id"] - if not info_id: - raise RuntimeError("failed to create asset metadata") - - pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) - if not pair: - raise RuntimeError("inconsistent DB state after ingest") - info, asset = pair - tag_names = get_asset_tags(session, asset_info_id=info.id) - created_result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=result["asset_created"], - ) - session.commit() - - return created_result - - -def update_asset( - *, - asset_info_id: str, - name: str | None = None, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetUpdated: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - info = update_asset_info_full( - session, - asset_info_id=asset_info_id, - name=name, - tags=tags, - user_metadata=user_metadata, - tag_origin="manual", - asset_info_row=info_row, - ) - - tag_names = get_asset_tags(session, asset_info_id=asset_info_id) - result = schemas_out.AssetUpdated( - id=info.id, - name=info.name, - asset_hash=info.asset.hash if info.asset else None, - tags=tag_names, - user_metadata=info.user_metadata or {}, - updated_at=info.updated_at, - ) - session.commit() - - return result - - -def set_asset_preview( - *, - asset_info_id: str, - preview_asset_id: str | None = None, - owner_id: str = "", -) -> schemas_out.AssetDetail: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - set_asset_info_preview( - session, - asset_info_id=asset_info_id, - preview_asset_id=preview_asset_id, - ) - - res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not res: - raise RuntimeError("State changed during preview update") - info, asset, tags = res - result = schemas_out.AssetDetail( - id=info.id, - name=info.name, - asset_hash=asset.hash if asset else None, - size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, - mime_type=asset.mime_type if asset else None, - tags=tags, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - ) - session.commit() - - return result - - -def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - asset_id = info_row.asset_id if info_row else None - deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not deleted: - session.commit() - return False - - if not delete_content_if_orphan or not asset_id: - session.commit() - return True - - still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) - if still_exists: - session.commit() - return True - - states = list_cache_states_by_asset_id(session, asset_id=asset_id) - file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] - - asset_row = session.get(Asset, asset_id) - if asset_row is not None: - session.delete(asset_row) - - session.commit() - for p in file_paths: - with contextlib.suppress(Exception): - if p and os.path.isfile(p): - os.remove(p) - return True - - -def create_asset_from_hash( - *, - hash_str: str, - name: str, - tags: list[str] | None = None, - user_metadata: dict | None = None, - owner_id: str = "", -) -> schemas_out.AssetCreated | None: - canonical = hash_str.strip().lower() - with create_session() as session: - asset = get_asset_by_hash(session, asset_hash=canonical) - if not asset: - return None - - info = create_asset_info_for_existing_asset( - session, - asset_hash=canonical, - name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), - user_metadata=user_metadata or {}, - tags=tags or [], - tag_origin="manual", - owner_id=owner_id, - ) - tag_names = get_asset_tags(session, asset_info_id=info.id) - result = schemas_out.AssetCreated( - id=info.id, - name=info.name, - asset_hash=asset.hash, - size=int(asset.size_bytes), - mime_type=asset.mime_type, - tags=tag_names, - user_metadata=info.user_metadata or {}, - preview_id=info.preview_id, - created_at=info.created_at, - last_access_time=info.last_access_time, - created_new=False, - ) - session.commit() - - return result - - -def add_tags_to_asset( - *, - asset_info_id: str, - tags: list[str], - origin: str = "manual", - owner_id: str = "", -) -> schemas_out.TagsAdd: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - data = add_tags_to_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=origin, - create_if_missing=True, - asset_info_row=info_row, - ) - session.commit() - return schemas_out.TagsAdd(**data) - - -def remove_tags_from_asset( - *, - asset_info_id: str, - tags: list[str], - owner_id: str = "", -) -> schemas_out.TagsRemove: - with create_session() as session: - info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) - if not info_row: - raise ValueError(f"AssetInfo {asset_info_id} not found") - if info_row.owner_id and info_row.owner_id != owner_id: - raise PermissionError("not owner") - - data = remove_tags_from_asset_info( - session, - asset_info_id=asset_info_id, - tags=tags, - ) - session.commit() - return schemas_out.TagsRemove(**data) - - -def list_tags( - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - order: str = "count_desc", - include_zero: bool = True, - owner_id: str = "", -) -> schemas_out.TagsList: - limit = max(1, min(1000, limit)) - offset = max(0, offset) - - with create_session() as session: - rows, total = list_tags_with_usage( - session, - prefix=prefix, - limit=limit, - offset=offset, - include_zero=include_zero, - order=order, - owner_id=owner_id, - ) - - tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] - return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 0172a5c2f..e27ea5123 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -1,263 +1,567 @@ -import contextlib -import time import logging import os -import sqlalchemy +from pathlib import Path +from typing import Callable, Literal, TypedDict import folder_paths -from app.database.db import create_session, dependencies_available -from app.assets.helpers import ( - collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, - list_tree,prefixes_for_root, escape_like_prefix, - RootType +from app.assets.database.queries import ( + add_missing_tag_for_asset_id, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + delete_orphaned_seed_asset, + delete_references_by_ids, + ensure_tags_exist, + get_asset_by_hash, + get_references_for_prefixes, + get_unenriched_references, + mark_references_missing_outside_prefixes, + reassign_asset_references, + remove_missing_tag_for_asset_id, + set_reference_metadata, + update_asset_hash_and_mime, ) -from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id -from app.assets.database.bulk_ops import seed_from_paths_batch -from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.services.bulk_ingest import ( + SeedAssetSpec, + batch_insert_seed_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + is_visible, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.metadata_extract import extract_file_metadata +from app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, +) +from app.database.db import create_session -def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: - """ - Scan the given roots and seed the assets into the database. - """ - if not dependencies_available(): - if enable_logging: - logging.warning("Database dependencies not available, skipping assets scan") - return - t_start = time.perf_counter() - created = 0 - skipped_existing = 0 - orphans_pruned = 0 - paths: list[str] = [] - try: - existing_paths: set[str] = set() - for r in roots: - try: - survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) - if survivors: - existing_paths.update(survivors) - except Exception as e: - logging.exception("fast DB scan failed for %s: %s", r, e) +class _RefInfo(TypedDict): + ref_id: str + file_path: str + exists: bool + stat_unchanged: bool + needs_verify: bool - try: - orphans_pruned = _prune_orphaned_assets(roots) - except Exception as e: - logging.exception("orphan pruning failed: %s", e) - if "models" in roots: - paths.extend(collect_models_files()) - if "input" in roots: - paths.extend(list_tree(folder_paths.get_input_directory())) - if "output" in roots: - paths.extend(list_tree(folder_paths.get_output_directory())) +class _AssetAccumulator(TypedDict): + hash: str | None + size_db: int + refs: list[_RefInfo] - specs: list[dict] = [] - tag_pool: set[str] = set() - for p in paths: - abs_p = os.path.abspath(p) - if abs_p in existing_paths: - skipped_existing += 1 + +RootType = Literal["models", "input", "output"] + + +def get_prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def get_all_known_prefixes() -> list[str]: + """Get all known asset prefixes across all root types.""" + all_roots: tuple[RootType, ...] = ("models", "input", "output") + return [p for root in all_roots for p in get_prefixes_for_root(root)] + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + if not all(is_visible(part) for part in Path(rel_path).parts): continue - try: - stat_p = os.stat(abs_p, follow_symlinks=False) - except OSError: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: continue - # skip empty files - if not stat_p.st_size: - continue - name, tags = get_name_and_tags_from_asset_path(abs_p) - specs.append( - { - "abs_path": abs_p, - "size_bytes": stat_p.st_size, - "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), - "info_name": name, - "tags": tags, - "fname": compute_relative_filename(abs_p), - } - ) - for t in tags: - tag_pool.add(t) - # if no file specs, nothing to do - if not specs: - return - with create_session() as sess: - if tag_pool: - ensure_tags_exist(sess, tag_pool, tag_type="user") - - result = seed_from_paths_batch(sess, specs=specs, owner_id="") - created += result["inserted_infos"] - sess.commit() - finally: - if enable_logging: - logging.info( - "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", - roots, - time.perf_counter() - t_start, - created, - skipped_existing, - orphans_pruned, - len(paths), - ) + abs_path = os.path.abspath(abs_path) + allowed = False + abs_p = Path(abs_path) + for b in bases: + if abs_p.is_relative_to(os.path.abspath(b)): + allowed = True + break + if allowed: + out.append(abs_path) + return out -def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: - """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" - all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] - if not all_prefixes: - return 0 - - def make_prefix_condition(prefix: str): - base = prefix if prefix.endswith(os.sep) else prefix + os.sep - escaped, esc = escape_like_prefix(base) - return AssetCacheState.file_path.like(escaped + "%", escape=esc) - - matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) - - orphan_subq = ( - sqlalchemy.select(Asset.id) - .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) - .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) - ).scalar_subquery() - - with create_session() as sess: - sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) - result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) - sess.commit() - return result.rowcount - - -def _fast_db_consistency_pass( +def sync_references_with_filesystem( + session, root: RootType, - *, collect_existing_paths: bool = False, update_missing_tags: bool = False, ) -> set[str] | None: - """Fast DB+FS pass for a root: - - Toggle needs_verify per state using fast check - - For hashed assets with at least one fast-ok state in this root: delete stale missing states - - For seed assets with all states missing: delete Asset and its AssetInfos - - Optionally add/remove 'missing' tags based on fast-ok in this root - - Optionally return surviving absolute paths + """Reconcile asset references with filesystem for a root. + + - Toggle needs_verify per reference using mtime/size stat check + - For hashed assets with at least one stat-unchanged ref: delete stale missing refs + - For seed assets with all refs missing: delete Asset and its references + - Optionally add/remove 'missing' tags based on stat check in this root + - Optionally return surviving absolute paths + + Args: + session: Database session + root: Root type to scan + collect_existing_paths: If True, return set of surviving file paths + update_missing_tags: If True, update 'missing' tags based on file status + + Returns: + Set of surviving absolute paths if collect_existing_paths=True, else None """ - prefixes = prefixes_for_root(root) + prefixes = get_prefixes_for_root(root) if not prefixes: return set() if collect_existing_paths else None - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_like_prefix(base) - conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + rows = get_references_for_prefixes( + session, prefixes, include_missing=update_missing_tags + ) + + by_asset: dict[str, _AssetAccumulator] = {} + for row in rows: + acc = by_asset.get(row.asset_id) + if acc is None: + acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} + by_asset[row.asset_id] = acc + + stat_unchanged = False + try: + exists = True + stat_unchanged = verify_file_unchanged( + mtime_db=row.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(row.file_path, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except PermissionError: + exists = True + logging.debug("Permission denied accessing %s", row.file_path) + except OSError as e: + exists = False + logging.debug("OSError checking %s: %s", row.file_path, e) + + acc["refs"].append( + { + "ref_id": row.reference_id, + "file_path": row.file_path, + "exists": exists, + "stat_unchanged": stat_unchanged, + "needs_verify": row.needs_verify, + } + ) + + to_set_verify: list[str] = [] + to_clear_verify: list[str] = [] + stale_ref_ids: list[str] = [] + to_mark_missing: list[str] = [] + to_clear_missing: list[str] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + refs = acc["refs"] + any_unchanged = any(r["stat_unchanged"] for r in refs) + all_missing = all(not r["exists"] for r in refs) + + for r in refs: + if not r["exists"]: + to_mark_missing.append(r["ref_id"]) + continue + if r["stat_unchanged"]: + to_clear_missing.append(r["ref_id"]) + if r["needs_verify"]: + to_clear_verify.append(r["ref_id"]) + if not r["stat_unchanged"] and not r["needs_verify"]: + to_set_verify.append(r["ref_id"]) + + if a_hash is None: + if refs and all_missing: + delete_orphaned_seed_asset(session, aid) + else: + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + continue + + if any_unchanged: + for r in refs: + if not r["exists"]: + stale_ref_ids.append(r["ref_id"]) + if update_missing_tags: + try: + remove_missing_tag_for_asset_id(session, asset_id=aid) + except Exception as e: + logging.warning( + "Failed to remove missing tag for asset %s: %s", aid, e + ) + elif update_missing_tags: + try: + add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic") + except Exception as e: + logging.warning("Failed to add missing tag for asset %s: %s", aid, e) + + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + + delete_references_by_ids(session, stale_ref_ids) + stale_set = set(stale_ref_ids) + to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set] + bulk_update_is_missing(session, to_mark_missing, value=True) + bulk_update_is_missing(session, to_clear_missing, value=False) + bulk_update_needs_verify(session, to_set_verify, value=True) + bulk_update_needs_verify(session, to_clear_verify, value=False) + + return survivors if collect_existing_paths else None + + +def sync_root_safely(root: RootType) -> set[str]: + """Sync a single root's references with the filesystem. + + Returns survivors (existing paths) or empty set on failure. + """ + try: + with create_session() as sess: + survivors = sync_references_with_filesystem( + sess, + root, + collect_existing_paths=True, + update_missing_tags=True, + ) + sess.commit() + return survivors or set() + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", root, e) + return set() + + +def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int: + """Mark references as missing when outside the given prefixes. + + This is a non-destructive soft-delete. Returns count marked or 0 on failure. + """ + try: + with create_session() as sess: + count = mark_references_missing_outside_prefixes(sess, prefixes) + sess.commit() + return count + except Exception as e: + logging.exception("marking missing assets failed: %s", e) + return 0 + + +def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: + """Collect all file paths for the given roots.""" + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_files_recursively(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_files_recursively(folder_paths.get_output_directory())) + return paths + + +def build_asset_specs( + paths: list[str], + existing_paths: set[str], + enable_metadata_extraction: bool = True, + compute_hashes: bool = False, +) -> tuple[list[SeedAssetSpec], set[str], int]: + """Build asset specs from paths, returning (specs, tag_pool, skipped_count). + + Args: + paths: List of file paths to process + existing_paths: Set of paths that already exist in the database + enable_metadata_extraction: If True, extract tier 1 & 2 metadata + compute_hashes: If True, compute blake3 hashes (slow for large files) + """ + specs: list[SeedAssetSpec] = [] + tag_pool: set[str] = set() + skipped = 0 + + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=True) + except OSError: + continue + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + rel_fname = compute_relative_filename(abs_p) + + # Extract metadata (tier 1: filesystem, tier 2: safetensors header) + metadata = None + if enable_metadata_extraction: + metadata = extract_file_metadata( + abs_p, + stat_result=stat_p, + relative_filename=rel_fname, + ) + + # Compute hash if requested + asset_hash: str | None = None + if compute_hashes: + try: + digest, _ = compute_blake3_hash(abs_p) + asset_hash = "blake3:" + digest + except Exception as e: + logging.warning("Failed to hash %s: %s", abs_p, e) + + mime_type = metadata.content_type if metadata else None + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": get_mtime_ns(stat_p), + "info_name": name, + "tags": tags, + "fname": rel_fname, + "metadata": metadata, + "hash": asset_hash, + "mime_type": mime_type, + } + ) + tag_pool.update(tags) + + return specs, tag_pool, skipped + + + +def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: + """Insert asset specs into database, returning count of created refs.""" + if not specs: + return 0 + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + result = batch_insert_seed_assets(sess, specs=specs, owner_id="") + sess.commit() + return result.inserted_refs + + +# Enrichment level constants +ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only +ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type) +ENRICHMENT_HASHED = 2 # Hash computed (blake3) + + +def get_unenriched_assets_for_roots( + roots: tuple[RootType, ...], + max_level: int = ENRICHMENT_STUB, + limit: int = 1000, +) -> list: + """Get assets that need enrichment for the given roots. + + Args: + roots: Tuple of root types to scan + max_level: Maximum enrichment level to include + limit: Maximum number of rows to return + + Returns: + List of UnenrichedReferenceRow + """ + prefixes: list[str] = [] + for root in roots: + prefixes.extend(get_prefixes_for_root(root)) + + if not prefixes: + return [] with create_session() as sess: - rows = ( - sess.execute( - sqlalchemy.select( - AssetCacheState.id, - AssetCacheState.file_path, - AssetCacheState.mtime_ns, - AssetCacheState.needs_verify, - AssetCacheState.asset_id, - Asset.hash, - Asset.size_bytes, - ) - .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(sqlalchemy.or_(*conds)) - .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + return get_unenriched_references( + sess, prefixes, max_level=max_level, limit=limit + ) + + +def enrich_asset( + session, + file_path: str, + reference_id: str, + asset_id: str, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> int: + """Enrich a single asset with metadata and/or hash. + + Args: + session: Database session (caller manages lifecycle) + file_path: Absolute path to the file + reference_id: ID of the reference to update + asset_id: ID of the asset to update (for mime_type and hash) + extract_metadata: If True, extract safetensors header and mime type + compute_hash: If True, compute blake3 hash + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + New enrichment level achieved + """ + new_level = ENRICHMENT_STUB + + try: + stat_p = os.stat(file_path, follow_symlinks=True) + except OSError: + return new_level + + rel_fname = compute_relative_filename(file_path) + mime_type: str | None = None + metadata = None + + if extract_metadata: + metadata = extract_file_metadata( + file_path, + stat_result=stat_p, + relative_filename=rel_fname, + ) + if metadata: + mime_type = metadata.content_type + new_level = ENRICHMENT_METADATA + + full_hash: str | None = None + if compute_hash: + try: + mtime_before = get_mtime_ns(stat_p) + size_before = stat_p.st_size + + # Restore checkpoint if available and file unchanged + checkpoint = None + if hash_checkpoints is not None: + checkpoint = hash_checkpoints.get(file_path) + if checkpoint is not None: + cur_stat = os.stat(file_path, follow_symlinks=True) + if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size): + checkpoint = None + hash_checkpoints.pop(file_path, None) + else: + mtime_before = get_mtime_ns(cur_stat) + + digest, new_checkpoint = compute_blake3_hash( + file_path, + interrupt_check=interrupt_check, + checkpoint=checkpoint, ) - ).all() - by_asset: dict[str, dict] = {} - for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: - acc = by_asset.get(aid) - if acc is None: - acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} - by_asset[aid] = acc + if digest is None: + # Interrupted — save checkpoint for later resumption + if hash_checkpoints is not None and new_checkpoint is not None: + new_checkpoint.mtime_ns = mtime_before + new_checkpoint.file_size = size_before + hash_checkpoints[file_path] = new_checkpoint + return new_level + + # Completed — clear any saved checkpoint + if hash_checkpoints is not None: + hash_checkpoints.pop(file_path, None) + + stat_after = os.stat(file_path, follow_symlinks=True) + mtime_after = get_mtime_ns(stat_after) + if mtime_before != mtime_after: + logging.warning("File modified during hashing, discarding hash: %s", file_path) + else: + full_hash = f"blake3:{digest}" + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: + new_level = ENRICHMENT_HASHED + except Exception as e: + logging.warning("Failed to hash %s: %s", file_path, e) + + if extract_metadata and metadata: + user_metadata = metadata.to_user_metadata() + set_reference_metadata(session, reference_id, user_metadata) + + if full_hash: + existing = get_asset_by_hash(session, full_hash) + if existing and existing.id != asset_id: + reassign_asset_references(session, asset_id, existing.id, reference_id) + delete_orphaned_seed_asset(session, asset_id) + if mime_type: + update_asset_hash_and_mime(session, existing.id, mime_type=mime_type) + else: + update_asset_hash_and_mime(session, asset_id, full_hash, mime_type) + elif mime_type: + update_asset_hash_and_mime(session, asset_id, mime_type=mime_type) + + bulk_update_enrichment_level(session, [reference_id], new_level) + session.commit() + + return new_level + + +def enrich_assets_batch( + rows: list, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> tuple[int, list[str]]: + """Enrich a batch of assets. + + Uses a single DB session for the entire batch, committing after each + individual asset to avoid long-held transactions while eliminating + per-asset session creation overhead. + + Args: + rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots + extract_metadata: If True, extract metadata for each asset + compute_hash: If True, compute hash for each asset + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + Tuple of (enriched_count, failed_reference_ids) + """ + enriched = 0 + failed_ids: list[str] = [] + + with create_session() as sess: + for row in rows: + if interrupt_check is not None and interrupt_check(): + break - fast_ok = False try: - exists = True - fast_ok = fast_asset_file_check( - mtime_db=mtime_db, - size_db=acc["size_db"], - stat_result=os.stat(fp, follow_symlinks=True), + new_level = enrich_asset( + sess, + file_path=row.file_path, + reference_id=row.reference_id, + asset_id=row.asset_id, + extract_metadata=extract_metadata, + compute_hash=compute_hash, + interrupt_check=interrupt_check, + hash_checkpoints=hash_checkpoints, ) - except FileNotFoundError: - exists = False - except OSError: - exists = False - - acc["states"].append({ - "sid": sid, - "fp": fp, - "exists": exists, - "fast_ok": fast_ok, - "needs_verify": bool(needs_verify), - }) - - to_set_verify: list[int] = [] - to_clear_verify: list[int] = [] - stale_state_ids: list[int] = [] - survivors: set[str] = set() - - for aid, acc in by_asset.items(): - a_hash = acc["hash"] - states = acc["states"] - any_fast_ok = any(s["fast_ok"] for s in states) - all_missing = all(not s["exists"] for s in states) - - for s in states: - if not s["exists"]: - continue - if s["fast_ok"] and s["needs_verify"]: - to_clear_verify.append(s["sid"]) - if not s["fast_ok"] and not s["needs_verify"]: - to_set_verify.append(s["sid"]) - - if a_hash is None: - if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) - asset = sess.get(Asset, aid) - if asset: - sess.delete(asset) + if new_level > row.enrichment_level: + enriched += 1 else: - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - continue + failed_ids.append(row.reference_id) + except Exception as e: + logging.warning("Failed to enrich %s: %s", row.file_path, e) + sess.rollback() + failed_ids.append(row.reference_id) - if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records - for s in states: - if not s["exists"]: - stale_state_ids.append(s["sid"]) - if update_missing_tags: - with contextlib.suppress(Exception): - remove_missing_tag_for_asset_id(sess, asset_id=aid) - elif update_missing_tags: - with contextlib.suppress(Exception): - add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - - if stale_state_ids: - sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) - if to_set_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_set_verify)) - .values(needs_verify=True) - ) - if to_clear_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_clear_verify)) - .values(needs_verify=False) - ) - sess.commit() - return survivors if collect_existing_paths else None + return enriched, failed_ids diff --git a/app/assets/seeder.py b/app/assets/seeder.py new file mode 100644 index 000000000..029448464 --- /dev/null +++ b/app/assets/seeder.py @@ -0,0 +1,794 @@ +"""Background asset seeder with thread management and cancellation support.""" + +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable + +from app.assets.scanner import ( + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + RootType, + build_asset_specs, + collect_paths_for_roots, + enrich_assets_batch, + get_all_known_prefixes, + get_prefixes_for_root, + get_unenriched_assets_for_roots, + insert_asset_specs, + mark_missing_outside_prefixes_safely, + sync_root_safely, +) +from app.database.db import dependencies_available + + +class ScanInProgressError(Exception): + """Raised when an operation cannot proceed because a scan is running.""" + + +class State(Enum): + """Seeder state machine states.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + CANCELLING = "CANCELLING" + + +class ScanPhase(Enum): + """Scan phase options.""" + + FAST = "fast" # Phase 1: filesystem only (stubs) + ENRICH = "enrich" # Phase 2: metadata + hash + FULL = "full" # Both phases sequentially + + +@dataclass +class Progress: + """Progress information for a scan operation.""" + + scanned: int = 0 + total: int = 0 + created: int = 0 + skipped: int = 0 + + +@dataclass +class ScanStatus: + """Current status of the asset seeder.""" + + state: State + progress: Progress | None + errors: list[str] = field(default_factory=list) + + +ProgressCallback = Callable[[Progress], None] + + +class _AssetSeeder: + """Background asset scanning manager. + + Spawns ephemeral daemon threads for scanning. + Each scan creates a new thread that exits when complete. + Use the module-level ``asset_seeder`` instance. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._state = State.IDLE + self._progress: Progress | None = None + self._last_progress: Progress | None = None + self._errors: list[str] = [] + self._thread: threading.Thread | None = None + self._cancel_event = threading.Event() + self._run_gate = threading.Event() + self._run_gate.set() # Start unpaused (set = running, clear = paused) + self._roots: tuple[RootType, ...] = () + self._phase: ScanPhase = ScanPhase.FULL + self._compute_hashes: bool = False + self._prune_first: bool = False + self._progress_callback: ProgressCallback | None = None + self._disabled: bool = False + + def disable(self) -> None: + """Disable the asset seeder, preventing any scans from starting.""" + self._disabled = True + logging.info("Asset seeder disabled") + + def is_disabled(self) -> bool: + """Check if the asset seeder is disabled.""" + return self._disabled + + def start( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + phase: ScanPhase = ScanPhase.FULL, + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + compute_hashes: bool = False, + ) -> bool: + """Start a background scan for the given roots. + + Args: + roots: Tuple of root types to scan (models, input, output) + phase: Scan phase to run (FAST, ENRICH, or FULL for both) + progress_callback: Optional callback called with progress updates + prune_first: If True, prune orphaned assets before scanning + compute_hashes: If True, compute blake3 hashes (slow) + + Returns: + True if scan was started, False if already running + """ + if self._disabled: + logging.debug("Asset seeder is disabled, skipping start") + return False + logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value) + with self._lock: + if self._state != State.IDLE: + logging.info("Asset seeder already running, skipping start") + return False + self._state = State.RUNNING + self._progress = Progress() + self._errors = [] + self._roots = roots + self._phase = phase + self._prune_first = prune_first + self._compute_hashes = compute_hashes + self._progress_callback = progress_callback + self._cancel_event.clear() + self._run_gate.set() # Ensure unpaused when starting + self._thread = threading.Thread( + target=self._run_scan, + name="_AssetSeeder", + daemon=True, + ) + self._thread.start() + return True + + def start_fast( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + ) -> bool: + """Start a fast scan (phase 1 only) - creates stub records. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + prune_first: If True, prune orphaned assets before scanning + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.FAST, + progress_callback=progress_callback, + prune_first=prune_first, + compute_hashes=False, + ) + + def start_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan (phase 2 only) - extracts metadata and hashes. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + compute_hashes: If True, compute blake3 hashes + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.ENRICH, + progress_callback=progress_callback, + prune_first=False, + compute_hashes=compute_hashes, + ) + + def cancel(self) -> bool: + """Request cancellation of the current scan. + + Returns: + True if cancellation was requested, False if not running or paused + """ + with self._lock: + if self._state not in (State.RUNNING, State.PAUSED): + return False + logging.info("Asset seeder cancelling (was %s)", self._state.value) + self._state = State.CANCELLING + self._cancel_event.set() + self._run_gate.set() # Unblock if paused so thread can exit + return True + + def stop(self) -> bool: + """Stop the current scan (alias for cancel). + + Returns: + True if stop was requested, False if not running + """ + return self.cancel() + + def pause(self) -> bool: + """Pause the current scan. + + The scan will complete its current batch before pausing. + + Returns: + True if pause was requested, False if not running + """ + with self._lock: + if self._state != State.RUNNING: + return False + logging.info("Asset seeder pausing") + self._state = State.PAUSED + self._run_gate.clear() + return True + + def resume(self) -> bool: + """Resume a paused scan. + + This is a noop if the scan is not in the PAUSED state + + Returns: + True if resumed, False if not paused + """ + with self._lock: + if self._state != State.PAUSED: + return False + logging.info("Asset seeder resuming") + self._state = State.RUNNING + self._run_gate.set() + self._emit_event("assets.seed.resumed", {}) + return True + + def restart( + self, + roots: tuple[RootType, ...] | None = None, + phase: ScanPhase | None = None, + progress_callback: ProgressCallback | None = None, + prune_first: bool | None = None, + compute_hashes: bool | None = None, + timeout: float = 5.0, + ) -> bool: + """Cancel any running scan and start a new one. + + Args: + roots: Roots to scan (defaults to previous roots) + phase: Scan phase (defaults to previous phase) + progress_callback: Progress callback (defaults to previous) + prune_first: Prune before scan (defaults to previous) + compute_hashes: Compute hashes (defaults to previous) + timeout: Max seconds to wait for current scan to stop + + Returns: + True if new scan was started, False if failed to stop previous + """ + logging.info("Asset seeder restart requested") + with self._lock: + prev_roots = self._roots + prev_phase = self._phase + prev_callback = self._progress_callback + prev_prune = self._prune_first + prev_hashes = self._compute_hashes + + self.cancel() + if not self.wait(timeout=timeout): + return False + + cb = progress_callback if progress_callback is not None else prev_callback + return self.start( + roots=roots if roots is not None else prev_roots, + phase=phase if phase is not None else prev_phase, + progress_callback=cb, + prune_first=prune_first if prune_first is not None else prev_prune, + compute_hashes=( + compute_hashes if compute_hashes is not None else prev_hashes + ), + ) + + def wait(self, timeout: float | None = None) -> bool: + """Wait for the current scan to complete. + + Args: + timeout: Maximum seconds to wait, or None for no timeout + + Returns: + True if scan completed, False if timeout expired or no scan running + """ + with self._lock: + thread = self._thread + if thread is None: + return True + thread.join(timeout=timeout) + return not thread.is_alive() + + def get_status(self) -> ScanStatus: + """Get the current status and progress of the seeder.""" + with self._lock: + src = self._progress or self._last_progress + return ScanStatus( + state=self._state, + progress=Progress( + scanned=src.scanned, + total=src.total, + created=src.created, + skipped=src.skipped, + ) + if src + else None, + errors=list(self._errors), + ) + + def shutdown(self, timeout: float = 5.0) -> None: + """Gracefully shutdown: cancel any running scan and wait for thread. + + Args: + timeout: Maximum seconds to wait for thread to exit + """ + self.cancel() + self.wait(timeout=timeout) + with self._lock: + self._thread = None + + def mark_missing_outside_prefixes(self) -> int: + """Mark references as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and their + metadata are preserved, but references are flagged as missing. + They can be restored if the file reappears in a future scan. + + This operation is decoupled from scanning to prevent partial scans + from accidentally marking assets belonging to other roots. + + Should be called explicitly when cleanup is desired, typically after + a full scan of all roots or during maintenance. + + Returns: + Number of references marked as missing + + Raises: + ScanInProgressError: If a scan is currently running + """ + with self._lock: + if self._state != State.IDLE: + raise ScanInProgressError( + "Cannot mark missing assets while scan is running" + ) + self._state = State.RUNNING + + try: + if not dependencies_available(): + logging.warning( + "Database dependencies not available, skipping mark missing" + ) + return 0 + + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d references as missing", marked) + return marked + finally: + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _is_cancelled(self) -> bool: + """Check if cancellation has been requested.""" + return self._cancel_event.is_set() + + def _is_paused_or_cancelled(self) -> bool: + """Non-blocking check: True if paused or cancelled. + + Use as interrupt_check for I/O-bound work (e.g. hashing) so that + file handles are released immediately on pause rather than held + open while blocked. The caller is responsible for blocking on + _check_pause_and_cancel() afterward. + """ + return not self._run_gate.is_set() or self._cancel_event.is_set() + + def _check_pause_and_cancel(self) -> bool: + """Block while paused, then check if cancelled. + + Call this at checkpoint locations in scan loops. It will: + 1. Block indefinitely while paused (until resume or cancel) + 2. Return True if cancelled, False to continue + + Returns: + True if scan should stop, False to continue + """ + if not self._run_gate.is_set(): + self._emit_event("assets.seed.paused", {}) + self._run_gate.wait() # Blocks if paused + return self._is_cancelled() + + def _emit_event(self, event_type: str, data: dict) -> None: + """Emit a WebSocket event if server is available.""" + try: + from server import PromptServer + + if hasattr(PromptServer, "instance") and PromptServer.instance: + PromptServer.instance.send_sync(event_type, data) + except Exception: + pass + + def _update_progress( + self, + scanned: int | None = None, + total: int | None = None, + created: int | None = None, + skipped: int | None = None, + ) -> None: + """Update progress counters (thread-safe).""" + callback: ProgressCallback | None = None + progress: Progress | None = None + + with self._lock: + if self._progress is None: + return + if scanned is not None: + self._progress.scanned = scanned + if total is not None: + self._progress.total = total + if created is not None: + self._progress.created = created + if skipped is not None: + self._progress.skipped = skipped + if self._progress_callback: + callback = self._progress_callback + progress = Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + + if callback and progress: + try: + callback(progress) + except Exception: + pass + + _MAX_ERRORS = 200 + + def _add_error(self, message: str) -> None: + """Add an error message (thread-safe), capped at _MAX_ERRORS.""" + with self._lock: + if len(self._errors) < self._MAX_ERRORS: + self._errors.append(message) + + def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: + """Log the directories that will be scanned.""" + import folder_paths + + for root in roots: + if root == "models": + logging.info( + "Asset scan [models] directory: %s", + os.path.abspath(folder_paths.models_dir), + ) + else: + prefixes = get_prefixes_for_root(root) + if prefixes: + logging.info("Asset scan [%s] directories: %s", root, prefixes) + + def _run_scan(self) -> None: + """Main scan loop running in background thread.""" + t_start = time.perf_counter() + roots = self._roots + phase = self._phase + cancelled = False + total_created = 0 + total_enriched = 0 + skipped_existing = 0 + total_paths = 0 + + try: + if not dependencies_available(): + self._add_error("Database dependencies not available") + self._emit_event( + "assets.seed.error", + {"message": "Database dependencies not available"}, + ) + return + + if self._prune_first: + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d refs as missing before scan", marked) + + if self._check_pause_and_cancel(): + logging.info("Asset scan cancelled after pruning phase") + cancelled = True + return + + self._log_scan_config(roots) + + # Phase 1: Fast scan (stub records) + if phase in (ScanPhase.FAST, ScanPhase.FULL): + created, skipped, paths = self._run_fast_phase(roots) + total_created, skipped_existing, total_paths = created, skipped, paths + + if self._check_pause_and_cancel(): + cancelled = True + return + + self._emit_event( + "assets.seed.fast_complete", + { + "roots": list(roots), + "created": total_created, + "skipped": skipped_existing, + "total": total_paths, + }, + ) + + # Phase 2: Enrichment scan (metadata + hashes) + if phase in (ScanPhase.ENRICH, ScanPhase.FULL): + if self._check_pause_and_cancel(): + cancelled = True + return + + enrich_cancelled, total_enriched = self._run_enrich_phase(roots) + + if enrich_cancelled: + cancelled = True + return + + self._emit_event( + "assets.seed.enrich_complete", + { + "roots": list(roots), + "enriched": total_enriched, + }, + ) + + elapsed = time.perf_counter() - t_start + logging.info( + "Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d", + roots, + phase.value, + elapsed, + total_created, + total_enriched, + skipped_existing, + ) + + self._emit_event( + "assets.seed.completed", + { + "phase": phase.value, + "total": total_paths, + "created": total_created, + "enriched": total_enriched, + "skipped": skipped_existing, + "elapsed": round(elapsed, 3), + }, + ) + + except Exception as e: + self._add_error(f"Scan failed: {e}") + logging.exception("Asset scan failed") + self._emit_event("assets.seed.error", {"message": str(e)}) + finally: + if cancelled: + self._emit_event( + "assets.seed.cancelled", + { + "scanned": self._progress.scanned if self._progress else 0, + "total": total_paths, + "created": total_created, + }, + ) + with self._lock: + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: + """Run phase 1: fast scan to create stub records. + + Returns: + Tuple of (total_created, skipped_existing, total_paths) + """ + t_fast_start = time.perf_counter() + total_created = 0 + skipped_existing = 0 + + existing_paths: set[str] = set() + t_sync = time.perf_counter() + for r in roots: + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + existing_paths.update(sync_root_safely(r)) + logging.debug( + "Fast scan: sync_root phase took %.3fs (%d existing paths)", + time.perf_counter() - t_sync, + len(existing_paths), + ) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + + t_collect = time.perf_counter() + paths = collect_paths_for_roots(roots) + logging.debug( + "Fast scan: collect_paths took %.3fs (%d paths found)", + time.perf_counter() - t_collect, + len(paths), + ) + total_paths = len(paths) + self._update_progress(total=total_paths) + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "total": total_paths, "phase": "fast"}, + ) + + # Use stub specs (no metadata extraction, no hashing) + t_specs = time.perf_counter() + specs, tag_pool, skipped_existing = build_asset_specs( + paths, + existing_paths, + enable_metadata_extraction=False, + compute_hashes=False, + ) + logging.debug( + "Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)", + time.perf_counter() - t_specs, + len(specs), + skipped_existing, + ) + self._update_progress(skipped=skipped_existing) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, total_paths + + batch_size = 500 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + for i in range(0, len(specs), batch_size): + if self._check_pause_and_cancel(): + logging.info( + "Fast scan cancelled after %d/%d files (created=%d)", + i, + len(specs), + total_created, + ) + return total_created, skipped_existing, total_paths + + batch = specs[i : i + batch_size] + batch_tags = {t for spec in batch for t in spec["tags"]} + try: + created = insert_asset_specs(batch, batch_tags) + total_created += created + except Exception as e: + self._add_error(f"Batch insert failed at offset {i}: {e}") + logging.exception("Batch insert failed at offset %d", i) + + scanned = i + len(batch) + now = time.perf_counter() + self._update_progress(scanned=scanned, created=total_created) + + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "fast", + "scanned": scanned, + "total": len(specs), + "created": total_created, + }, + ) + last_progress_time = now + + self._update_progress(scanned=len(specs), created=total_created) + logging.info( + "Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)", + time.perf_counter() - t_fast_start, + total_created, + skipped_existing, + total_paths, + ) + return total_created, skipped_existing, total_paths + + def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: + """Run phase 2: enrich existing records with metadata and hashes. + + Returns: + Tuple of (cancelled, total_enriched) + """ + total_enriched = 0 + batch_size = 100 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + # Get the target enrichment level based on compute_hashes + if not self._compute_hashes: + target_max_level = ENRICHMENT_STUB + else: + target_max_level = ENRICHMENT_METADATA + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "phase": "enrich"}, + ) + + skip_ids: set[str] = set() + consecutive_empty = 0 + max_consecutive_empty = 3 + + # Hash checkpoints survive across batches so interrupted hashes + # can be resumed without re-reading the entire file. + hash_checkpoints: dict[str, object] = {} + + while True: + if self._check_pause_and_cancel(): + logging.info("Enrich scan cancelled after %d assets", total_enriched) + return True, total_enriched + + # Fetch next batch of unenriched assets + unenriched = get_unenriched_assets_for_roots( + roots, + max_level=target_max_level, + limit=batch_size, + ) + + # Filter out previously failed references + if skip_ids: + unenriched = [r for r in unenriched if r.reference_id not in skip_ids] + + if not unenriched: + break + + enriched, failed_ids = enrich_assets_batch( + unenriched, + extract_metadata=True, + compute_hash=self._compute_hashes, + interrupt_check=self._is_paused_or_cancelled, + hash_checkpoints=hash_checkpoints, + ) + total_enriched += enriched + skip_ids.update(failed_ids) + + if enriched == 0: + consecutive_empty += 1 + if consecutive_empty >= max_consecutive_empty: + logging.warning( + "Enrich phase stopping: %d consecutive batches with no progress (%d skipped)", + consecutive_empty, + len(skip_ids), + ) + break + else: + consecutive_empty = 0 + + now = time.perf_counter() + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "enrich", + "enriched": total_enriched, + }, + ) + last_progress_time = now + + return False, total_enriched + + +asset_seeder = _AssetSeeder() diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py new file mode 100644 index 000000000..11fcb4122 --- /dev/null +++ b/app/assets/services/__init__.py @@ -0,0 +1,87 @@ +from app.assets.services.asset_management import ( + asset_exists, + delete_asset_reference, + get_asset_by_hash, + get_asset_detail, + list_assets_page, + resolve_asset_for_download, + set_asset_preview, + update_asset_metadata, +) +from app.assets.services.bulk_ingest import ( + BulkInsertResult, + batch_insert_seed_assets, + cleanup_unreferenced_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + get_size_and_mtime_ns, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.ingest import ( + DependencyMissingError, + HashMismatchError, + create_from_hash, + upload_from_temp_path, +) +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, +) +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + IngestResult, + ListAssetsResult, + ReferenceData, + RegisterAssetResult, + TagUsage, + UploadResult, + UserMetadata, +) +from app.assets.services.tagging import ( + apply_tags, + list_tags, + remove_tags, +) + +__all__ = [ + "AddTagsResult", + "AssetData", + "AssetDetailResult", + "AssetSummaryData", + "ReferenceData", + "BulkInsertResult", + "DependencyMissingError", + "DownloadResolutionResult", + "HashMismatchError", + "IngestResult", + "ListAssetsResult", + "RegisterAssetResult", + "RemoveTagsResult", + "TagUsage", + "UploadResult", + "UserMetadata", + "apply_tags", + "asset_exists", + "batch_insert_seed_assets", + "create_from_hash", + "delete_asset_reference", + "get_asset_by_hash", + "get_asset_detail", + "get_mtime_ns", + "get_size_and_mtime_ns", + "list_assets_page", + "list_files_recursively", + "list_tags", + "cleanup_unreferenced_assets", + "remove_tags", + "resolve_asset_for_download", + "set_asset_preview", + "update_asset_metadata", + "upload_from_temp_path", + "verify_file_unchanged", +] diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py new file mode 100644 index 000000000..3fe7115c8 --- /dev/null +++ b/app/assets/services/asset_management.py @@ -0,0 +1,309 @@ +import contextlib +import mimetypes +import os +from typing import Sequence + + +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + reference_exists_for_asset_id, + delete_reference_by_id, + fetch_reference_and_asset, + soft_delete_reference_by_id, + fetch_reference_asset_and_tags, + get_asset_by_hash as queries_get_asset_by_hash, + get_reference_by_id, + get_reference_with_owner_check, + list_references_page, + list_references_by_asset_id, + set_reference_metadata, + set_reference_preview, + set_reference_tags, + update_reference_access_time, + update_reference_name, + update_reference_updated_at, +) +from app.assets.helpers import select_best_live_path +from app.assets.services.path_utils import compute_relative_filename +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + ListAssetsResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def get_asset_detail( + reference_id: str, + owner_id: str = "", +) -> AssetDetailResult | None: + with create_session() as session: + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + return None + + ref, asset, tags = result + return AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + + +def update_asset_metadata( + reference_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: UserMetadata = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + ref = get_reference_with_owner_check(session, reference_id, owner_id) + + touched = False + if name is not None and name != ref.name: + update_reference_name(session, reference_id=reference_id, name=name) + touched = True + + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + + new_meta: dict | None = None + if user_metadata is not None: + new_meta = dict(user_metadata) + elif computed_filename: + current_meta = ref.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + + if new_meta is not None: + if computed_filename: + new_meta["filename"] = computed_filename + set_reference_metadata( + session, reference_id=reference_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_reference_tags( + session, + reference_id=reference_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + update_reference_updated_at(session, reference_id=reference_id) + + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + raise RuntimeError("State changed during update") + + ref, asset, tag_list = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_list, + ) + session.commit() + + return detail + + +def delete_asset_reference( + reference_id: str, + owner_id: str, + delete_content_if_orphan: bool = True, +) -> bool: + with create_session() as session: + if not delete_content_if_orphan: + # Soft delete: mark the reference as deleted but keep everything + deleted = soft_delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + session.commit() + return deleted + + ref_row = get_reference_by_id(session, reference_id=reference_id) + asset_id = ref_row.asset_id if ref_row else None + file_path = ref_row.file_path if ref_row else None + + deleted = delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + if not deleted: + session.commit() + return False + + if not asset_id: + session.commit() + return True + + still_exists = reference_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + # Orphaned asset - delete it and its files + refs = list_references_by_asset_id(session, asset_id=asset_id) + file_paths = [ + r.file_path for r in (refs or []) if getattr(r, "file_path", None) + ] + # Also include the just-deleted file path + if file_path: + file_paths.append(file_path) + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + + # Delete files after commit + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + + return True + + +def set_asset_preview( + reference_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + set_reference_preview( + session, + reference_id=reference_id, + preview_asset_id=preview_asset_id, + ) + + result = fetch_reference_asset_and_tags( + session, reference_id=reference_id, owner_id=owner_id + ) + if not result: + raise RuntimeError("State changed during preview update") + + ref, asset, tags = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + session.commit() + + return detail + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + + +def get_asset_by_hash(asset_hash: str) -> AssetData | None: + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash=asset_hash) + return extract_asset_data(asset) + + +def list_assets_page( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> ListAssetsResult: + with create_session() as session: + refs, tag_map, total = list_references_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + items: list[AssetSummaryData] = [] + for ref in refs: + items.append( + AssetSummaryData( + ref=extract_reference_data(ref), + asset=extract_asset_data(ref.asset), + tags=tag_map.get(ref.id, []), + ) + ) + + return ListAssetsResult(items=items, total=total) + + +def resolve_asset_for_download( + reference_id: str, + owner_id: str = "", +) -> DownloadResolutionResult: + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise ValueError(f"AssetReference {reference_id} not found") + + ref, asset = pair + + # For references with file_path, use that directly + if ref.file_path and os.path.isfile(ref.file_path): + abs_path = ref.file_path + else: + # For API-created refs without file_path, find a path from other refs + refs = list_references_by_asset_id(session, asset_id=asset.id) + abs_path = select_best_live_path(refs) + if not abs_path: + raise FileNotFoundError( + f"No live path for AssetReference {reference_id} " + f"(asset id={asset.id}, name={ref.name})" + ) + + # Capture ORM attributes before commit (commit expires loaded objects) + ref_name = ref.name + asset_mime = asset.mime_type + + update_reference_access_time(session, reference_id=reference_id) + session.commit() + + ctype = ( + asset_mime + or mimetypes.guess_type(ref_name or abs_path)[0] + or "application/octet-stream" + ) + download_name = ref_name or os.path.basename(abs_path) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=download_name, + ) diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py new file mode 100644 index 000000000..54e72730c --- /dev/null +++ b/app/assets/services/bulk_ingest.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, TypedDict + +from sqlalchemy.orm import Session + +from app.assets.database.queries import ( + bulk_insert_assets, + bulk_insert_references_ignore_conflicts, + bulk_insert_tags_and_meta, + delete_assets_by_ids, + get_existing_asset_ids, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_unreferenced_unhashed_asset_ids, + restore_references_by_paths, +) +from app.assets.helpers import get_utc_now + +if TYPE_CHECKING: + from app.assets.services.metadata_extract import ExtractedMetadata + + +class SeedAssetSpec(TypedDict): + """Spec for seeding an asset from filesystem.""" + + abs_path: str + size_bytes: int + mtime_ns: int + info_name: str + tags: list[str] + fname: str + metadata: ExtractedMetadata | None + hash: str | None + mime_type: str | None + + +class AssetRow(TypedDict): + """Row data for inserting an Asset.""" + + id: str + hash: str | None + size_bytes: int + mime_type: str | None + created_at: datetime + + +class ReferenceRow(TypedDict): + """Row data for inserting an AssetReference.""" + + id: str + asset_id: str + file_path: str + mtime_ns: int + owner_id: str + name: str + preview_id: str | None + user_metadata: dict[str, Any] | None + created_at: datetime + updated_at: datetime + last_access_time: datetime + + +class TagRow(TypedDict): + """Row data for inserting a Tag.""" + + asset_reference_id: str + tag_name: str + origin: str + added_at: datetime + + +class MetadataRow(TypedDict): + """Row data for inserting asset metadata.""" + + asset_reference_id: str + key: str + ordinal: int + val_str: str | None + val_num: float | None + val_bool: bool | None + val_json: dict[str, Any] | None + + +@dataclass +class BulkInsertResult: + """Result of bulk asset insertion.""" + + inserted_refs: int + won_paths: int + lost_paths: int + + +def batch_insert_seed_assets( + session: Session, + specs: list[SeedAssetSpec], + owner_id: str = "", +) -> BulkInsertResult: + """Seed assets from filesystem specs in batch. + + Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + + This function orchestrates: + 1. Insert seed Assets (hash=NULL) + 2. Claim references with ON CONFLICT DO NOTHING on file_path + 3. Query to find winners (paths where our asset_id was inserted) + 4. Delete Assets for losers (path already claimed by another asset) + 5. Insert tags and metadata for successfully inserted references + + Returns: + BulkInsertResult with inserted_refs, won_paths, lost_paths + """ + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + + current_time = get_utc_now() + asset_rows: list[AssetRow] = [] + reference_rows: list[ReferenceRow] = [] + path_to_asset_id: dict[str, str] = {} + asset_id_to_ref_data: dict[str, dict] = {} + absolute_path_list: list[str] = [] + + for spec in specs: + absolute_path = os.path.abspath(spec["abs_path"]) + asset_id = str(uuid.uuid4()) + reference_id = str(uuid.uuid4()) + absolute_path_list.append(absolute_path) + path_to_asset_id[absolute_path] = asset_id + + mime_type = spec.get("mime_type") + asset_rows.append( + { + "id": asset_id, + "hash": spec.get("hash"), + "size_bytes": spec["size_bytes"], + "mime_type": mime_type, + "created_at": current_time, + } + ) + + # Build user_metadata from extracted metadata or fallback to filename + extracted_metadata = spec.get("metadata") + if extracted_metadata: + user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata() + elif spec["fname"]: + user_metadata = {"filename": spec["fname"]} + else: + user_metadata = None + + reference_rows.append( + { + "id": reference_id, + "asset_id": asset_id, + "file_path": absolute_path, + "mtime_ns": spec["mtime_ns"], + "owner_id": owner_id, + "name": spec["info_name"], + "preview_id": None, + "user_metadata": user_metadata, + "created_at": current_time, + "updated_at": current_time, + "last_access_time": current_time, + } + ) + + asset_id_to_ref_data[asset_id] = { + "reference_id": reference_id, + "tags": spec["tags"], + "filename": spec["fname"], + "extracted_metadata": extracted_metadata, + } + + bulk_insert_assets(session, asset_rows) + + # Filter reference rows to only those whose assets were actually inserted + # (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING) + inserted_asset_ids = get_existing_asset_ids( + session, [r["asset_id"] for r in reference_rows] + ) + reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids] + + bulk_insert_references_ignore_conflicts(session, reference_rows) + restore_references_by_paths(session, absolute_path_list) + winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id) + + inserted_paths = { + path + for path in absolute_path_list + if path_to_asset_id[path] in inserted_asset_ids + } + losing_paths = inserted_paths - winning_paths + lost_asset_ids = [path_to_asset_id[path] for path in losing_paths] + + if lost_asset_ids: + delete_assets_by_ids(session, lost_asset_ids) + + if not winning_paths: + return BulkInsertResult( + inserted_refs=0, + won_paths=0, + lost_paths=len(losing_paths), + ) + + # Get reference IDs for winners + winning_ref_ids = [ + asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"] + for path in winning_paths + ] + inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids) + + tag_rows: list[TagRow] = [] + metadata_rows: list[MetadataRow] = [] + + if inserted_ref_ids: + for path in winning_paths: + asset_id = path_to_asset_id[path] + ref_data = asset_id_to_ref_data[asset_id] + ref_id = ref_data["reference_id"] + + if ref_id not in inserted_ref_ids: + continue + + for tag in ref_data["tags"]: + tag_rows.append( + { + "asset_reference_id": ref_id, + "tag_name": tag, + "origin": "automatic", + "added_at": current_time, + } + ) + + # Use extracted metadata for meta rows if available + extracted_metadata = ref_data.get("extracted_metadata") + if extracted_metadata: + metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id)) + elif ref_data["filename"]: + # Fallback: just store filename + metadata_rows.append( + { + "asset_reference_id": ref_id, + "key": "filename", + "ordinal": 0, + "val_str": ref_data["filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows) + + return BulkInsertResult( + inserted_refs=len(inserted_ref_ids), + won_paths=len(winning_paths), + lost_paths=len(losing_paths), + ) + + +def cleanup_unreferenced_assets(session: Session) -> int: + """Hard-delete unhashed assets with no active references. + + This is a destructive operation intended for explicit cleanup. + Only deletes assets where hash=None and all references are missing. + + Returns: + Number of assets deleted + """ + unreferenced_ids = get_unreferenced_unhashed_asset_ids(session) + return delete_assets_by_ids(session, unreferenced_ids) diff --git a/app/assets/services/file_utils.py b/app/assets/services/file_utils.py new file mode 100644 index 000000000..c47ebe460 --- /dev/null +++ b/app/assets/services/file_utils.py @@ -0,0 +1,70 @@ +import os + + +def get_mtime_ns(stat_result: os.stat_result) -> int: + """Extract mtime in nanoseconds from a stat result.""" + return getattr( + stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000) + ) + + +def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]: + """Get file size in bytes and mtime in nanoseconds.""" + st = os.stat(path, follow_symlinks=follow_symlinks) + return st.st_size, get_mtime_ns(st) + + +def verify_file_unchanged( + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + """Check if a file is unchanged based on mtime and size. + + Returns True if the file's mtime and size match the database values. + Returns False if mtime_db is None or values don't match. + + size_db=None means don't check size; 0 is a valid recorded size. + """ + if mtime_db is None: + return False + actual_mtime_ns = get_mtime_ns(stat_result) + if int(mtime_db) != int(actual_mtime_ns): + return False + if size_db is not None: + return int(stat_result.st_size) == int(size_db) + return True + + +def is_visible(name: str) -> bool: + """Return True if a file or directory name is visible (not hidden).""" + return not name.startswith(".") + + +def list_files_recursively(base_dir: str) -> list[str]: + """Recursively list all files in a directory, following symlinks.""" + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + # Track seen real directory identities to prevent circular symlink loops + seen_dirs: set[tuple[int, int]] = set() + for dirpath, subdirs, filenames in os.walk( + base_abs, topdown=True, followlinks=True + ): + try: + st = os.stat(dirpath) + dir_id = (st.st_dev, st.st_ino) + except OSError: + subdirs.clear() + continue + if dir_id in seen_dirs: + subdirs.clear() + continue + seen_dirs.add(dir_id) + subdirs[:] = [d for d in subdirs if is_visible(d)] + for name in filenames: + if not is_visible(name): + continue + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py new file mode 100644 index 000000000..92aee6402 --- /dev/null +++ b/app/assets/services/hashing.py @@ -0,0 +1,95 @@ +import io +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import IO, Any, Callable, Iterator + +from blake3 import blake3 + +DEFAULT_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +@dataclass +class HashCheckpoint: + """Saved state for resuming an interrupted hash computation.""" + + bytes_processed: int + hasher: Any # blake3 hasher instance + mtime_ns: int = 0 + file_size: int = 0 + + +@contextmanager +def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]: + """Yield (file_object, is_path) with appropriate setup/teardown.""" + if hasattr(fp, "read"): + seekable = getattr(fp, "seekable", lambda: False)() + orig_pos = None + if seekable: + try: + orig_pos = fp.tell() + if orig_pos != 0: + fp.seek(0) + except io.UnsupportedOperation: + orig_pos = None + try: + yield fp, False + finally: + if orig_pos is not None: + fp.seek(orig_pos) + else: + with open(os.fspath(fp), "rb") as f: + yield f, True + + +def compute_blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, + interrupt_check: InterruptCheck | None = None, + checkpoint: HashCheckpoint | None = None, +) -> tuple[str | None, HashCheckpoint | None]: + """Compute BLAKE3 hash of a file, with optional checkpoint support. + + Args: + fp: File path or file-like object + chunk_size: Size of chunks to read at a time + interrupt_check: Optional callable that returns True if the operation + should be interrupted (e.g. paused or cancelled). Must be + non-blocking so file handles are released immediately. Checked + between chunk reads. + checkpoint: Optional checkpoint to resume from (file paths only) + + Returns: + Tuple of (hex_digest, None) on completion, or + (None, checkpoint) on interruption (file paths only), or + (None, None) on interruption of a file object + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + with _open_for_hashing(fp) as (f, is_path): + if checkpoint is not None and is_path: + f.seek(checkpoint.bytes_processed) + h = checkpoint.hasher + bytes_processed = checkpoint.bytes_processed + else: + h = blake3() + bytes_processed = 0 + + while True: + if interrupt_check is not None and interrupt_check(): + if is_path: + return None, HashCheckpoint( + bytes_processed=bytes_processed, + hasher=h, + ) + return None, None + chunk = f.read(chunk_size) + if not chunk: + break + h.update(chunk) + bytes_processed += len(chunk) + + return h.hexdigest(), None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py new file mode 100644 index 000000000..44d7aef36 --- /dev/null +++ b/app/assets/services/ingest.py @@ -0,0 +1,375 @@ +import contextlib +import logging +import mimetypes +import os +from typing import Any, Sequence + +from sqlalchemy.orm import Session + +import app.assets.services.hashing as hashing +from app.assets.database.queries import ( + add_tags_to_reference, + fetch_reference_and_asset, + get_asset_by_hash, + get_existing_asset_ids, + get_reference_by_file_path, + get_reference_tags, + get_or_create_reference, + remove_missing_tag_for_asset_id, + set_reference_metadata, + set_reference_tags, + upsert_asset, + upsert_reference, + validate_tags_exist, +) +from app.assets.helpers import normalize_tags +from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.path_utils import ( + compute_relative_filename, + resolve_destination_from_tags, + validate_path_within_base, +) +from app.assets.services.schemas import ( + IngestResult, + RegisterAssetResult, + UploadResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def _ingest_file_from_path( + abs_path: str, + asset_hash: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: UserMetadata = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> IngestResult: + locator = os.path.abspath(abs_path) + user_metadata = user_metadata or {} + + asset_created = False + asset_updated = False + ref_created = False + ref_updated = False + reference_id: str | None = None + + with create_session() as session: + if preview_id: + if preview_id not in get_existing_asset_ids(session, [preview_id]): + preview_id = None + + asset, asset_created, asset_updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=size_bytes, + mime_type=mime_type, + ) + + ref_created, ref_updated = upsert_reference( + session, + asset_id=asset.id, + file_path=locator, + name=info_name or os.path.basename(locator), + mtime_ns=mtime_ns, + owner_id=owner_id, + ) + + # Get the reference we just created/updated + ref = get_reference_by_file_path(session, locator) + if ref: + reference_id = ref.id + + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + + norm = normalize_tags(list(tags)) + if norm: + if require_existing_tags: + validate_tags_exist(session, norm) + add_tags_to_reference( + session, + reference_id=reference_id, + tags=norm, + origin=tag_origin, + create_if_missing=not require_existing_tags, + ) + + _update_metadata_with_filename( + session, + reference_id=reference_id, + file_path=ref.file_path, + current_metadata=ref.user_metadata, + user_metadata=user_metadata, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + + session.commit() + + return IngestResult( + asset_created=asset_created, + asset_updated=asset_updated, + ref_created=ref_created, + ref_updated=ref_updated, + reference_id=reference_id, + ) + + +def _register_existing_asset( + asset_hash: str, + name: str, + user_metadata: UserMetadata = None, + tags: list[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> RegisterAssetResult: + user_metadata = user_metadata or {} + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"No asset with hash {asset_hash}") + + ref, ref_created = get_or_create_reference( + session, + asset_id=asset.id, + owner_id=owner_id, + name=name, + ) + + if not ref_created: + tag_names = get_reference_tags(session, reference_id=ref.id) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=False, + ) + session.commit() + return result + + new_meta = dict(user_metadata) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta: + set_reference_metadata( + session, + reference_id=ref.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_reference_tags( + session, + reference_id=ref.id, + tags=tags, + origin=tag_origin, + ) + + tag_names = get_reference_tags(session, reference_id=ref.id) + session.refresh(ref) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=True, + ) + session.commit() + + return result + + + +def _update_metadata_with_filename( + session: Session, + reference_id: str, + file_path: str | None, + current_metadata: dict | None, + user_metadata: dict[str, Any], +) -> None: + computed_filename = compute_relative_filename(file_path) if file_path else None + + current_meta = current_metadata or {} + new_meta = dict(current_meta) + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + set_reference_metadata( + session, + reference_id=reference_id, + user_metadata=new_meta, + ) + + +def _sanitize_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + return n if n else fallback + + +class HashMismatchError(Exception): + pass + + +class DependencyMissingError(Exception): + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +def upload_from_temp_path( + temp_path: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + client_filename: str | None = None, + owner_id: str = "", + expected_hash: str | None = None, +) -> UploadResult: + try: + digest, _ = hashing.compute_blake3_hash(temp_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_hash and asset_hash != expected_hash.strip().lower(): + raise HashMismatchError("Uploaded file hash does not match provided hash.") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _sanitize_filename(name or client_filename, fallback=digest) + result = _register_existing_asset( + asset_hash=asset_hash, + name=display_name, + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) + + if not tags: + raise ValueError("tags are required for new asset uploads") + base_dir, subdirs = resolve_destination_from_tags(tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + validate_path_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + ingest_result = _ingest_file_from_path( + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name or client_filename, fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=user_metadata or {}, + tags=tags, + tag_origin="manual", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + +def create_from_hash( + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> UploadResult | None: + canonical = hash_str.strip().lower() + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + result = _register_existing_asset( + asset_hash=canonical, + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py new file mode 100644 index 000000000..a004929bc --- /dev/null +++ b/app/assets/services/metadata_extract.py @@ -0,0 +1,327 @@ +"""Metadata extraction for asset scanning. + +Tier 1: Filesystem metadata (zero parsing) +Tier 2: Safetensors header metadata (fast JSON read only) +""" + +from __future__ import annotations + +import json +import logging +import mimetypes +import os +import struct +from dataclasses import dataclass +from typing import Any + +from utils.mime_types import init_mime_types + +init_mime_types() + +# Supported safetensors extensions +SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"}) + +# Maximum safetensors header size to read (8MB) +MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024 + + +@dataclass +class ExtractedMetadata: + """Metadata extracted from a file during scanning.""" + + # Tier 1: Filesystem (always available) + filename: str = "" + file_path: str = "" # Full absolute path to the file + content_length: int = 0 + content_type: str | None = None + format: str = "" # file extension without dot + + # Tier 2: Safetensors header (if available) + base_model: str | None = None + trained_words: list[str] | None = None + air: str | None = None # CivitAI AIR identifier + has_preview_images: bool = False + + # Source provenance (populated if embedded in safetensors) + source_url: str | None = None + source_arn: str | None = None + repo_url: str | None = None + preview_url: str | None = None + source_hash: str | None = None + + # HuggingFace specific + repo_id: str | None = None + revision: str | None = None + filepath: str | None = None + resolve_url: str | None = None + + def to_user_metadata(self) -> dict[str, Any]: + """Convert to user_metadata dict for AssetReference.user_metadata JSON field.""" + data: dict[str, Any] = { + "filename": self.filename, + "content_length": self.content_length, + "format": self.format, + } + if self.file_path: + data["file_path"] = self.file_path + if self.content_type: + data["content_type"] = self.content_type + + # Tier 2 fields + if self.base_model: + data["base_model"] = self.base_model + if self.trained_words: + data["trained_words"] = self.trained_words + if self.air: + data["air"] = self.air + if self.has_preview_images: + data["has_preview_images"] = True + + # Source provenance + if self.source_url: + data["source_url"] = self.source_url + if self.source_arn: + data["source_arn"] = self.source_arn + if self.repo_url: + data["repo_url"] = self.repo_url + if self.preview_url: + data["preview_url"] = self.preview_url + if self.source_hash: + data["source_hash"] = self.source_hash + + # HuggingFace + if self.repo_id: + data["repo_id"] = self.repo_id + if self.revision: + data["revision"] = self.revision + if self.filepath: + data["filepath"] = self.filepath + if self.resolve_url: + data["resolve_url"] = self.resolve_url + + return data + + def to_meta_rows(self, reference_id: str) -> list[dict]: + """Convert to asset_reference_meta rows for typed/indexed querying.""" + rows: list[dict] = [] + + def add_str(key: str, val: str | None, ordinal: int = 0) -> None: + if val: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": ordinal, + "val_str": val[:2048] if len(val) > 2048 else val, + "val_num": None, + "val_bool": None, + "val_json": None, + }) + + def add_num(key: str, val: int | float | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": val, + "val_bool": None, + "val_json": None, + }) + + def add_bool(key: str, val: bool | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": None, + "val_bool": val, + "val_json": None, + }) + + # Tier 1 + add_str("filename", self.filename) + add_num("content_length", self.content_length) + add_str("content_type", self.content_type) + add_str("format", self.format) + + # Tier 2 + add_str("base_model", self.base_model) + add_str("air", self.air) + has_previews = self.has_preview_images if self.has_preview_images else None + add_bool("has_preview_images", has_previews) + + # trained_words as multiple rows with ordinals + if self.trained_words: + for i, word in enumerate(self.trained_words[:100]): # limit to 100 words + add_str("trained_words", word, ordinal=i) + + # Source provenance + add_str("source_url", self.source_url) + add_str("source_arn", self.source_arn) + add_str("repo_url", self.repo_url) + add_str("preview_url", self.preview_url) + add_str("source_hash", self.source_hash) + + # HuggingFace + add_str("repo_id", self.repo_id) + add_str("revision", self.revision) + add_str("filepath", self.filepath) + add_str("resolve_url", self.resolve_url) + + return rows + + +def _read_safetensors_header( + path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE +) -> dict[str, Any] | None: + """Read only the JSON header from a safetensors file. + + This is very fast - reads 8 bytes for header length, then the JSON header. + No tensor data is loaded. + + Args: + path: Absolute path to safetensors file + max_size: Maximum header size to read (default 8MB) + + Returns: + Parsed header dict or None if failed + """ + try: + with open(path, "rb") as f: + header_bytes = f.read(8) + if len(header_bytes) < 8: + return None + length_of_header = struct.unpack(" max_size: + return None + header_data = f.read(length_of_header) + if len(header_data) < length_of_header: + return None + return json.loads(header_data.decode("utf-8")) + except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error): + return None + + +def _extract_safetensors_metadata( + header: dict[str, Any], meta: ExtractedMetadata +) -> None: + """Extract metadata from safetensors header __metadata__ section. + + Modifies meta in-place. + """ + st_meta = header.get("__metadata__", {}) + if not isinstance(st_meta, dict): + return + + # Common model metadata + meta.base_model = ( + st_meta.get("ss_base_model_version") + or st_meta.get("modelspec.base_model") + or st_meta.get("base_model") + ) + + # Trained words / trigger words + trained_words = st_meta.get("ss_tag_frequency") + if trained_words and isinstance(trained_words, str): + try: + tag_freq = json.loads(trained_words) + # Extract unique tags from all datasets + all_tags: set[str] = set() + for dataset_tags in tag_freq.values(): + if isinstance(dataset_tags, dict): + all_tags.update(dataset_tags.keys()) + if all_tags: + meta.trained_words = sorted(all_tags)[:100] + except json.JSONDecodeError: + pass + + # Direct trained_words field (some formats) + if not meta.trained_words: + tw = st_meta.get("trained_words") + if isinstance(tw, str): + try: + parsed = json.loads(tw) + if isinstance(parsed, list): + meta.trained_words = [str(x) for x in parsed] + else: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + except json.JSONDecodeError: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + elif isinstance(tw, list): + meta.trained_words = [str(x) for x in tw] + + # CivitAI AIR + meta.air = st_meta.get("air") or st_meta.get("modelspec.air") + + # Preview images (ssmd_cover_images) + cover_images = st_meta.get("ssmd_cover_images") + if cover_images: + meta.has_preview_images = True + + # Source provenance fields + meta.source_url = st_meta.get("source_url") + meta.source_arn = st_meta.get("source_arn") + meta.repo_url = st_meta.get("repo_url") + meta.preview_url = st_meta.get("preview_url") + meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash") + + # HuggingFace fields + meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id") + meta.revision = st_meta.get("revision") or st_meta.get("hf_revision") + meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath") + meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url") + + +def extract_file_metadata( + abs_path: str, + stat_result: os.stat_result | None = None, + relative_filename: str | None = None, +) -> ExtractedMetadata: + """Extract metadata from a file using tier 1 and tier 2 methods. + + Tier 1: Filesystem metadata from path and stat + Tier 2: Safetensors header parsing if applicable + + Args: + abs_path: Absolute path to the file + stat_result: Optional pre-fetched stat result (saves a syscall) + relative_filename: Optional relative filename to use instead of basename + (e.g., "flux/123/model.safetensors" for model paths) + + Returns: + ExtractedMetadata with all available fields populated + """ + meta = ExtractedMetadata() + + # Tier 1: Filesystem metadata + meta.filename = relative_filename or os.path.basename(abs_path) + meta.file_path = abs_path + _, ext = os.path.splitext(abs_path) + meta.format = ext.lstrip(".").lower() if ext else "" + + mime_type, _ = mimetypes.guess_type(abs_path) + meta.content_type = mime_type + + # Size from stat + if stat_result is None: + try: + stat_result = os.stat(abs_path, follow_symlinks=True) + except OSError: + pass + + if stat_result: + meta.content_length = stat_result.st_size + + # Tier 2: Safetensors header (if applicable and enabled) + if ext.lower() in SAFETENSORS_EXTENSIONS: + header = _read_safetensors_header(abs_path) + if header: + try: + _extract_safetensors_metadata(header, meta) + except Exception as e: + logging.debug("Safetensors meta extract failed %s: %s", abs_path, e) + + return meta diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py new file mode 100644 index 000000000..f5dd7f7fd --- /dev/null +++ b/app/assets/services/path_utils.py @@ -0,0 +1,167 @@ +import os +from pathlib import Path +from typing import Literal + +import folder_paths +from app.assets.helpers import normalize_tags + + +_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build list of (folder_name, base_paths[]) for all model locations. + + Includes every category registered in folder_names_and_paths, + regardless of whether its paths are under the main models_dir, + but excludes non-model entries like custom_nodes. + """ + targets: list[tuple[str, list[str]]] = [] + for name, values in folder_paths.folder_names_and_paths.items(): + if name in _NON_MODEL_FOLDER_NAMES: + continue + paths, _exts = values[0], values[1] + if paths: + targets.append((name, paths)) + return targets + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + if not tags: + raise ValueError("tags must not be empty") + root = tags[0].lower() + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + elif root == "input": + base_dir = os.path.abspath(folder_paths.get_input_directory()) + raw_subdirs = tags[1:] + elif root == "output": + base_dir = os.path.abspath(folder_paths.get_output_directory()) + raw_subdirs = tags[1:] + else: + raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") + _sep_chars = frozenset(("/", "\\", os.sep)) + for i in raw_subdirs: + if i in (".", "..") or _sep_chars & set(i): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + + +def validate_path_within_base(candidate: str, base: str) -> None: + cand_abs = Path(os.path.abspath(candidate)) + base_abs = Path(os.path.abspath(base)) + if not cand_abs.is_relative_to(base_abs): + raise ValueError("destination escapes base directory") + + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + """ + try: + root_category, rel_path = get_asset_category_and_relative_path(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_asset_category_and_relative_path( + file_path: str, +) -> tuple[Literal["input", "output", "models"], str]: + """Determine which root category a file path belongs to. + + Categories: + - 'input': under folder_paths.get_input_directory() + - 'output': under folder_paths.get_output_directory() + - 'models': under any base path from get_comfy_models_folders() + + Returns: + (root_category, relative_path_inside_that_root) + + Raises: + ValueError: path does not belong to any known root. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return "input", _compute_relative(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return "output", _compute_relative(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _check_is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError( + f"Path is not within input, output, or configured model bases: {file_path}" + ) + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return (name, tags) derived from a filesystem path. + + - name: base filename with extension + - tags: [root_category] + parent folder names in order + + Raises: + ValueError: path does not belong to any known root. + """ + root_category, some_path = get_asset_category_and_relative_path(file_path) + p = Path(some_path) + parent_parts = [ + part for part in p.parent.parts if part not in (".", "..", p.anchor) + ] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py new file mode 100644 index 000000000..8b1f1f4dc --- /dev/null +++ b/app/assets/services/schemas.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NamedTuple + +from app.assets.database.models import Asset, AssetReference + +UserMetadata = dict[str, Any] | None + + +@dataclass(frozen=True) +class AssetData: + hash: str | None + size_bytes: int | None + mime_type: str | None + + +@dataclass(frozen=True) +class ReferenceData: + """Data transfer object for AssetReference.""" + + id: str + name: str + file_path: str | None + user_metadata: UserMetadata + preview_id: str | None + created_at: datetime + updated_at: datetime + last_access_time: datetime | None + + +@dataclass(frozen=True) +class AssetDetailResult: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class RegisterAssetResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created: bool + + +@dataclass(frozen=True) +class IngestResult: + asset_created: bool + asset_updated: bool + ref_created: bool + ref_updated: bool + reference_id: str | None + + +class TagUsage(NamedTuple): + name: str + tag_type: str + count: int + + +@dataclass(frozen=True) +class AssetSummaryData: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class ListAssetsResult: + items: list[AssetSummaryData] + total: int + + +@dataclass(frozen=True) +class DownloadResolutionResult: + abs_path: str + content_type: str + download_name: str + + +@dataclass(frozen=True) +class UploadResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created_new: bool + + +def extract_reference_data(ref: AssetReference) -> ReferenceData: + return ReferenceData( + id=ref.id, + name=ref.name, + file_path=ref.file_path, + user_metadata=ref.user_metadata, + preview_id=ref.preview_id, + created_at=ref.created_at, + updated_at=ref.updated_at, + last_access_time=ref.last_access_time, + ) + + +def extract_asset_data(asset: Asset | None) -> AssetData | None: + if asset is None: + return None + return AssetData( + hash=asset.hash, + size_bytes=asset.size_bytes, + mime_type=asset.mime_type, + ) diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py new file mode 100644 index 000000000..28900464d --- /dev/null +++ b/app/assets/services/tagging.py @@ -0,0 +1,75 @@ +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, + add_tags_to_reference, + get_reference_with_owner_check, + list_tags_with_usage, + remove_tags_from_reference, +) +from app.assets.services.schemas import TagUsage +from app.database.db import create_session + + +def apply_tags( + reference_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> AddTagsResult: + with create_session() as session: + ref_row = get_reference_with_owner_check(session, reference_id, owner_id) + + result = add_tags_to_reference( + session, + reference_id=reference_id, + tags=tags, + origin=origin, + create_if_missing=True, + reference_row=ref_row, + ) + session.commit() + + return result + + +def remove_tags( + reference_id: str, + tags: list[str], + owner_id: str = "", +) -> RemoveTagsResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + result = remove_tags_from_reference( + session, + reference_id=reference_id, + tags=tags, + ) + session.commit() + + return result + + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> tuple[list[TagUsage], int]: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total diff --git a/app/database/db.py b/app/database/db.py index 1de8b80ed..0aab09a49 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -3,6 +3,7 @@ import os import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message +from filelock import FileLock, Timeout from comfy.cli_args import args _DB_AVAILABLE = False @@ -14,8 +15,12 @@ try: from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine + from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker + from sqlalchemy.pool import StaticPool + + from app.database.models import Base + import app.assets.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: @@ -65,9 +70,69 @@ def get_db_path(): raise ValueError(f"Unsupported database URL '{url}'.") +_db_lock = None + +def _acquire_file_lock(db_path): + """Acquire an OS-level file lock to prevent multi-process access. + + Uses filelock for cross-platform support (macOS, Linux, Windows). + The OS automatically releases the lock when the process exits, even on crashes. + """ + global _db_lock + lock_path = db_path + ".lock" + _db_lock = FileLock(lock_path) + try: + _db_lock.acquire(timeout=0) + except Timeout: + raise RuntimeError( + f"Could not acquire lock on database '{db_path}'. " + "Another ComfyUI process may already be using it. " + "Use --database-url to specify a separate database file." + ) + + +def _is_memory_db(db_url): + """Check if the database URL refers to an in-memory SQLite database.""" + return db_url in ("sqlite:///:memory:", "sqlite://") + + def init_db(): db_url = args.database_url logging.debug(f"Database URL: {db_url}") + + if _is_memory_db(db_url): + _init_memory_db(db_url) + else: + _init_file_db(db_url) + + +def _init_memory_db(db_url): + """Initialize an in-memory SQLite database using metadata.create_all. + + Alembic migrations don't work with in-memory SQLite because each + connection gets its own separate database — tables created by Alembic's + internal connection are lost immediately. + """ + engine = create_engine( + db_url, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + Base.metadata.create_all(engine) + + global Session + Session = sessionmaker(bind=engine) + + +def _init_file_db(db_url): + """Initialize a file-backed SQLite database using Alembic migrations.""" db_path = get_db_path() db_exists = os.path.exists(db_path) @@ -75,6 +140,14 @@ def init_db(): # Check if we need to upgrade engine = create_engine(db_url) + + # Enable foreign key enforcement for SQLite + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + conn = engine.connect() context = MigrationContext.configure(conn) @@ -104,6 +177,12 @@ def init_db(): logging.exception("Error upgrading database: ") raise e + # Acquire an OS-level file lock after migrations are complete. + # Alembic uses its own connection, so we must wait until it's done + # before locking — otherwise our own lock blocks the migration. + conn.close() + _acquire_file_lock(db_path) + global Session Session = sessionmaker(bind=engine) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13079c7bc..e9832acaf 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -232,7 +232,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") -parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") +parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index a90a5ca40..9f6918315 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, + "assets": args.enable_assets, } diff --git a/main.py b/main.py index 0f58d57b8..a8fc1a28d 100644 --- a/main.py +++ b/main.py @@ -7,14 +7,16 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.scanner import seed_assets +from app.assets.seeder import asset_seeder import itertools import utils.extra_config +from utils.mime_types import init_mime_types import logging import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +from app.database.db import init_db, dependencies_available if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -161,6 +163,7 @@ def execute_prestartup_script(): logging.info("") apply_custom_paths() +init_mime_types() if args.enable_manager: comfyui_manager.prestartup() @@ -258,6 +261,7 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True @@ -302,6 +306,7 @@ def prompt_worker(q, server_instance): last_gc_collect = current_time need_gc = False hook_breaker_ac10a0.restore_functions() + asset_seeder.resume() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): @@ -352,12 +357,29 @@ def cleanup_temp(): def setup_database(): try: - from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() - if not args.disable_assets_autoscan: - seed_assets(["models"], enable_logging=True) + if args.enable_assets: + if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): + logging.info("Background asset scan initiated for models, input, output") except Exception as e: + if "database is locked" in str(e): + logging.error( + "Database is locked. Another ComfyUI process is already using this database.\n" + "To resolve this, specify a separate database file for this instance:\n" + " --database-url sqlite:///path/to/another.db" + ) + sys.exit(1) + if args.enable_assets: + logging.error( + f"Failed to initialize database: {e}\n" + "The --enable-assets flag requires a working database connection.\n" + "To resolve this, try one of the following:\n" + " 1. Install the latest requirements: pip install -r requirements.txt\n" + " 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n" + " 3. Use an in-memory database: --database-url sqlite:///:memory:" + ) + sys.exit(1) logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") @@ -440,5 +462,6 @@ if __name__ == "__main__": event_loop.run_until_complete(x) except KeyboardInterrupt: logging.info("\nStopped server") - - cleanup_temp() + finally: + asset_seeder.shutdown() + cleanup_temp() diff --git a/requirements.txt b/requirements.txt index dc9a9ded0..9527135ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,11 +20,13 @@ tqdm psutil alembic SQLAlchemy +filelock av>=14.2.0 comfy-kitchen>=0.2.7 comfy-aimdo>=0.2.7 requests simpleeval>=1.0.0 +blake3 #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 275bce5a7..76904ebc9 100644 --- a/server.py +++ b/server.py @@ -33,8 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal -from app.assets.scanner import seed_assets -from app.assets.api.routes import register_assets_system +from app.assets.seeder import asset_seeder +from app.assets.api.routes import register_assets_routes from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -197,10 +197,6 @@ class PromptServer(): def __init__(self, loop): PromptServer.instance = self - mimetypes.init() - mimetypes.add_type('application/javascript; charset=utf-8', '.js') - mimetypes.add_type('image/webp', '.webp') - self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() @@ -239,7 +235,11 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") - register_assets_system(self.app, self.user_manager) + if args.enable_assets: + register_assets_routes(self.app, self.user_manager) + else: + register_assets_routes(self.app) + asset_seeder.disable() routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -697,10 +697,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): - try: - seed_assets(["models"]) - except Exception as e: - logging.error(f"Failed to seed assets: {e}") + asset_seeder.start(roots=("models", "input", "output")) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 0a57dd7b5..6c5c56113 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest) "main.py", f"--base-directory={str(comfy_tmp_base_dir)}", f"--database-url={db_url}", - "--disable-assets-autoscan", + "--enable-assets", "--listen", "127.0.0.1", "--port", @@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str): for aid in created: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) @pytest.fixture @@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str): break for aid in ids: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}", timeout=30) - - -def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: - """Force a fast sync/seed pass by calling the seed endpoint.""" - session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) - time.sleep(0.2) - - -def get_asset_filename(asset_hash: str, extension: str) -> str: - return asset_hash.removeprefix("blake3:") + extension + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) diff --git a/tests-unit/assets_test/helpers.py b/tests-unit/assets_test/helpers.py new file mode 100644 index 000000000..770e011f4 --- /dev/null +++ b/tests-unit/assets_test/helpers.py @@ -0,0 +1,28 @@ +"""Helper functions for assets integration tests.""" +import time + +import requests + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a synchronous sync/seed pass by calling the seed endpoint with wait=true. + + Retries on 409 (already running) until the previous scan finishes. + """ + deadline = time.monotonic() + 60 + while True: + r = session.post( + base_url + "/api/assets/seed?wait=true", + json={"roots": ["models", "input", "output"]}, + timeout=60, + ) + if r.status_code != 409: + assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}" + return + if time.monotonic() > deadline: + raise TimeoutError("seed endpoint stuck in 409 (already running)") + time.sleep(0.25) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/queries/conftest.py b/tests-unit/assets_test/queries/conftest.py new file mode 100644 index 000000000..4ca0e86a9 --- /dev/null +++ b/tests-unit/assets_test/queries/conftest.py @@ -0,0 +1,20 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture +def session(): + """In-memory SQLite session for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + with Session(engine) as sess: + yield sess + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - query tests don't need server cleanup.""" + yield diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py new file mode 100644 index 000000000..08f84cd11 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset.py @@ -0,0 +1,144 @@ +import uuid + +import pytest +from sqlalchemy.orm import Session + +from app.assets.helpers import get_utc_now +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + get_asset_by_hash, + upsert_asset, + bulk_insert_assets, +) + + +class TestAssetExistsByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,expected", + [ + (None, "nonexistent", False), # No asset exists + ("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash + (None, "", False), # Null hash in DB doesn't match empty string + ], + ids=["nonexistent", "existing", "null_hash_no_match"], + ) + def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected): + if setup_hash is not None or query_hash == "": + asset = Asset(hash=setup_hash, size_bytes=100) + session.add(asset) + session.commit() + + assert asset_exists_by_hash(session, asset_hash=query_hash) is expected + + +class TestGetAssetByHash: + @pytest.mark.parametrize( + "setup_hash,query_hash,should_find", + [ + (None, "nonexistent", False), + ("blake3:def456", "blake3:def456", True), + ], + ids=["nonexistent", "existing"], + ) + def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find): + if setup_hash is not None: + asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png") + session.add(asset) + session.commit() + + result = get_asset_by_hash(session, asset_hash=query_hash) + if should_find: + assert result is not None + assert result.size_bytes == 200 + assert result.mime_type == "image/png" + else: + assert result is None + + +class TestUpsertAsset: + @pytest.mark.parametrize( + "first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime", + [ + # New asset creation + (None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"), + # Existing asset, same values - no update + (500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"), + # Existing asset with size 0, update with new values + (0, None, 2048, "image/png", False, True, 2048, "image/png"), + # Existing asset, second call with size 0 - no update + (1000, None, 0, None, False, False, 1000, None), + ], + ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"], + ) + def test_upsert_scenarios( + self, + session: Session, + first_size, + first_mime, + second_size, + second_mime, + expect_created, + expect_updated, + final_size, + final_mime, + ): + asset_hash = f"blake3:test_{first_size}_{second_size}" + + # First upsert (if first_size is not None, we're testing the second call) + if first_size is not None: + upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=first_size, + mime_type=first_mime, + ) + session.commit() + + # The upsert call we're testing + asset, created, updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=second_size, + mime_type=second_mime, + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + assert asset.size_bytes == final_size + assert asset.mime_type == final_mime + + +class TestBulkInsertAssets: + def test_inserts_multiple_assets(self, session: Session): + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now}, + {"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now}, + ] + bulk_insert_assets(session, rows) + session.commit() + + assets = session.query(Asset).all() + assert len(assets) == 3 + hashes = {a.hash for a in assets} + assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"} + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_assets(session, []) + session.commit() + assert session.query(Asset).count() == 0 + + def test_handles_large_batch(self, session: Session): + """Test chunking logic with more rows than MAX_BIND_PARAMS allows.""" + now = get_utc_now() + rows = [ + {"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now} + for i in range(200) + ] + bulk_insert_assets(session, rows) + session.commit() + + assert session.query(Asset).count() == 200 diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py new file mode 100644 index 000000000..8f6c7fcdb --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -0,0 +1,517 @@ +import time +import uuid +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import ( + reference_exists_for_asset_id, + get_reference_by_id, + insert_reference, + get_or_create_reference, + update_reference_timestamps, + list_references_page, + fetch_reference_asset_and_tags, + fetch_reference_and_asset, + update_reference_access_time, + set_reference_metadata, + delete_reference_by_id, + set_reference_preview, + bulk_insert_references_ignore_conflicts, + get_reference_ids_by_ids, + ensure_tags_exist, + add_tags_to_reference, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestReferenceExistsForAssetId: + def test_returns_false_when_no_reference(self, session: Session): + asset = _make_asset(session, "hash1") + assert reference_exists_for_asset_id(session, asset_id=asset.id) is False + + def test_returns_true_when_reference_exists(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset) + assert reference_exists_for_asset_id(session, asset_id=asset.id) is True + + +class TestGetReferenceById: + def test_returns_none_for_nonexistent(self, session: Session): + assert get_reference_by_id(session, reference_id="nonexistent") is None + + def test_returns_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="myfile.txt") + + result = get_reference_by_id(session, reference_id=ref.id) + assert result is not None + assert result.name == "myfile.txt" + + +class TestListReferencesPage: + def test_empty_db(self, session: Session): + refs, tag_map, total = list_references_page(session) + assert refs == [] + assert tag_map == {} + assert total == 0 + + def test_returns_references_with_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + refs, tag_map, total = list_references_page(session) + assert len(refs) == 1 + assert refs[0].id == ref.id + assert set(tag_map[ref.id]) == {"alpha", "beta"} + assert total == 1 + + def test_name_contains_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="model_v1.safetensors") + _make_reference(session, asset, name="config.json") + session.commit() + + refs, _, total = list_references_page(session, name_contains="model") + assert total == 1 + assert refs[0].name == "model_v1.safetensors" + + def test_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="public", owner_id="") + _make_reference(session, asset, name="private", owner_id="user1") + session.commit() + + # Empty owner sees only public + refs, _, total = list_references_page(session, owner_id="") + assert total == 1 + assert refs[0].name == "public" + + # Owner sees both + refs, _, total = list_references_page(session, owner_id="user1") + assert total == 2 + + def test_include_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="tagged") + _make_reference(session, asset, name="untagged") + ensure_tags_exist(session, ["wanted"]) + add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"]) + session.commit() + + refs, _, total = list_references_page(session, include_tags=["wanted"]) + assert total == 1 + assert refs[0].name == "tagged" + + def test_exclude_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="keep") + ref_exclude = _make_reference(session, asset, name="exclude") + ensure_tags_exist(session, ["bad"]) + add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"]) + session.commit() + + refs, _, total = list_references_page(session, exclude_tags=["bad"]) + assert total == 1 + assert refs[0].name == "keep" + + def test_sorting(self, session: Session): + asset = _make_asset(session, "hash1", size=100) + asset2 = _make_asset(session, "hash2", size=500) + _make_reference(session, asset, name="small") + _make_reference(session, asset2, name="large") + session.commit() + + refs, _, _ = list_references_page(session, sort="size", order="desc") + assert refs[0].name == "large" + + refs, _, _ = list_references_page(session, sort="name", order="asc") + assert refs[0].name == "large" + + +class TestFetchReferenceAssetAndTags: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_asset_and_tags(session, "nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["tag1"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"]) + session.commit() + + result = fetch_reference_asset_and_tags(session, ref.id) + assert result is not None + ret_ref, ret_asset, ret_tags = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + assert ret_tags == ["tag1"] + + +class TestFetchReferenceAndAsset: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_reference_and_asset(session, reference_id="nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = fetch_reference_and_asset(session, reference_id=ref.id) + assert result is not None + ret_ref, ret_asset = result + assert ret_ref.id == ref.id + assert ret_asset.id == asset.id + + +class TestUpdateReferenceAccessTime: + def test_updates_last_access_time(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_time = ref.last_access_time + session.commit() + + import time + time.sleep(0.01) + + update_reference_access_time(session, reference_id=ref.id) + session.commit() + + session.refresh(ref) + assert ref.last_access_time > original_time + + +class TestDeleteReferenceById: + def test_deletes_existing(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="") + assert result is True + assert get_reference_by_id(session, reference_id=ref.id) is None + + def test_returns_false_for_nonexistent(self, session: Session): + result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="") + assert result is False + + def test_respects_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2") + assert result is False + assert get_reference_by_id(session, reference_id=ref.id) is not None + + +class TestSetReferencePreview: + def test_sets_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + def test_clears_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + session.commit() + + set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) + session.commit() + + session.refresh(ref) + assert ref.preview_id is None + + def test_raises_for_nonexistent_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) + + def test_raises_for_nonexistent_preview(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + with pytest.raises(ValueError, match="Preview Asset"): + set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") + + +class TestInsertReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="test.bin" + ) + session.commit() + + assert ref is not None + assert ref.name == "test.bin" + assert ref.owner_id == "user1" + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin") + session.commit() + + # Duplicate names are now allowed + ref2 = insert_reference( + session, asset_id=asset.id, owner_id="user1", name="dup.bin" + ) + session.commit() + + assert ref1 is not None + assert ref2 is not None + assert ref1.id != ref2.id + + +class TestGetOrCreateReference: + def test_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref, created = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="new.bin" + ) + session.commit() + + assert created is True + assert ref.name == "new.bin" + + def test_always_creates_new_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref1, created1 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + # Duplicate names are allowed, so always creates new + ref2, created2 = get_or_create_reference( + session, asset_id=asset.id, owner_id="user1", name="existing.bin" + ) + session.commit() + + assert created1 is True + assert created2 is True + assert ref1.id != ref2.id + + +class TestUpdateReferenceTimestamps: + def test_updates_timestamps(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + original_updated_at = ref.updated_at + session.commit() + + time.sleep(0.01) + update_reference_timestamps(session, ref) + session.commit() + + session.refresh(ref) + assert ref.updated_at > original_updated_at + + def test_updates_preview_id(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + ref = _make_reference(session, asset) + session.commit() + + update_reference_timestamps(session, ref, preview_id=preview_asset.id) + session.commit() + + session.refresh(ref) + assert ref.preview_id == preview_asset.id + + +class TestSetReferenceMetadata: + def test_sets_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {"key": "value"} + # Check metadata table + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "key" + assert meta[0].val_str == "value" + + def test_replaces_existing_metadata(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"old": "data"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"new": "data"} + ) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "new" + + def test_clears_metadata_with_empty_dict(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={"key": "value"} + ) + session.commit() + + set_reference_metadata( + session, reference_id=ref.id, user_metadata={} + ) + session.commit() + + session.refresh(ref) + assert ref.user_metadata == {} + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 0 + + def test_raises_for_nonexistent(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_reference_metadata( + session, reference_id="nonexistent", user_metadata={"key": "value"} + ) + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk1.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "bulk2.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + refs = session.query(AssetReference).all() + assert len(refs) == 2 + + def test_allows_duplicate_names(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, name="existing.bin", owner_id="") + session.commit() + + now = get_utc_now() + rows = [ + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "existing.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "id": str(uuid.uuid4()), + "owner_id": "", + "name": "new.bin", + "asset_id": asset.id, + "preview_id": None, + "user_metadata": {}, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + # Duplicate names allowed, so all 3 rows exist + refs = session.query(AssetReference).all() + assert len(refs) == 3 + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferenceIdsByIds: + def test_returns_existing_ids(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="a.bin") + ref2 = _make_reference(session, asset, name="b.bin") + session.commit() + + found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"]) + + assert found == {ref1.id, ref2.id} + + def test_empty_list_returns_empty(self, session: Session): + found = get_reference_ids_by_ids(session, []) + assert found == set() diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py new file mode 100644 index 000000000..ead60e570 --- /dev/null +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -0,0 +1,499 @@ +"""Tests for cache_state (AssetReference file path) query functions.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ( + list_references_by_asset_id, + upsert_reference, + get_unreferenced_unhashed_asset_ids, + delete_assets_by_ids, + get_references_for_prefixes, + bulk_update_needs_verify, + delete_references_by_ids, + delete_orphaned_seed_asset, + bulk_insert_references_ignore_conflicts, + get_references_by_paths_and_asset_ids, + mark_references_missing_outside_prefixes, + restore_references_by_paths, +) +from app.assets.helpers import select_best_live_path, get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + file_path: str, + name: str = "test", + mtime_ns: int | None = None, + needs_verify: bool = False, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + file_path=file_path, + name=name, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestListReferencesByAssetId: + def test_returns_empty_for_no_references(self, session: Session): + asset = _make_asset(session, "hash1") + refs = list_references_by_asset_id(session, asset_id=asset.id) + assert list(refs) == [] + + def test_returns_references_for_asset(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/path/a.bin", name="a") + _make_reference(session, asset, "/path/b.bin", name="b") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset.id) + paths = [r.file_path for r in refs] + assert set(paths) == {"/path/a.bin", "/path/b.bin"} + + def test_does_not_return_other_assets_references(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + _make_reference(session, asset1, "/path/asset1.bin", name="a1") + _make_reference(session, asset2, "/path/asset2.bin", name="a2") + session.commit() + + refs = list_references_by_asset_id(session, asset_id=asset1.id) + paths = [r.file_path for r in refs] + assert paths == ["/path/asset1.bin"] + + +class TestSelectBestLivePath: + def test_returns_empty_for_empty_list(self): + result = select_best_live_path([]) + assert result == "" + + def test_returns_empty_when_no_files_exist(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, "/nonexistent/path.bin") + session.commit() + + result = select_best_live_path([ref]) + assert result == "" + + def test_prefers_verified_path(self, session: Session, tmp_path): + """needs_verify=False should be preferred.""" + asset = _make_asset(session, "hash1") + + verified_file = tmp_path / "verified.bin" + verified_file.write_bytes(b"data") + + unverified_file = tmp_path / "unverified.bin" + unverified_file.write_bytes(b"data") + + ref_verified = _make_reference( + session, asset, str(verified_file), name="verified", needs_verify=False + ) + ref_unverified = _make_reference( + session, asset, str(unverified_file), name="unverified", needs_verify=True + ) + session.commit() + + refs = [ref_unverified, ref_verified] + result = select_best_live_path(refs) + assert result == str(verified_file) + + def test_falls_back_to_existing_unverified(self, session: Session, tmp_path): + """If all references need verification, return first existing path.""" + asset = _make_asset(session, "hash1") + + existing_file = tmp_path / "exists.bin" + existing_file.write_bytes(b"data") + + ref = _make_reference(session, asset, str(existing_file), needs_verify=True) + session.commit() + + result = select_best_live_path([ref]) + assert result == str(existing_file) + + +class TestSelectBestLivePathWithMocking: + def test_handles_missing_file_path_attr(self): + """Gracefully handle references with None file_path.""" + + class MockRef: + file_path = None + needs_verify = False + + result = select_best_live_path([MockRef()]) + assert result == "" + + +class TestUpsertReference: + @pytest.mark.parametrize( + "initial_mtime,second_mtime,expect_created,expect_updated,final_mtime", + [ + # New reference creation + (None, 12345, True, False, 12345), + # Existing reference, same mtime - no update + (100, 100, False, False, 100), + # Existing reference, different mtime - update + (100, 200, False, True, 200), + ], + ids=["new_reference", "existing_no_change", "existing_update_mtime"], + ) + def test_upsert_scenarios( + self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime + ): + asset = _make_asset(session, "hash1") + file_path = f"/path_{initial_mtime}_{second_mtime}.bin" + name = f"file_{initial_mtime}_{second_mtime}" + + # Create initial reference if needed + if initial_mtime is not None: + upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime) + session.commit() + + # The upsert call we're testing + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime + ) + session.commit() + + assert created is expect_created + assert updated is expect_updated + ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert ref.mtime_ns == final_mtime + + def test_upsert_restores_missing_reference(self, session: Session): + """Upserting a reference that was marked missing should restore it.""" + asset = _make_asset(session, "hash1") + file_path = "/restored/file.bin" + + ref = _make_reference(session, asset, file_path, mtime_ns=100) + ref.is_missing = True + session.commit() + + created, updated = upsert_reference( + session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100 + ) + session.commit() + + assert created is False + assert updated is True + restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one() + assert restored_ref.is_missing is False + + +class TestRestoreReferencesByPaths: + def test_restores_missing_references(self, session: Session): + asset = _make_asset(session, "hash1") + missing_path = "/missing/file.bin" + active_path = "/active/file.bin" + + missing_ref = _make_reference(session, asset, missing_path, name="missing") + missing_ref.is_missing = True + _make_reference(session, asset, active_path, name="active") + session.commit() + + restored = restore_references_by_paths(session, [missing_path]) + session.commit() + + assert restored == 1 + ref = session.query(AssetReference).filter_by(file_path=missing_path).one() + assert ref.is_missing is False + + def test_empty_list_restores_nothing(self, session: Session): + restored = restore_references_by_paths(session, []) + assert restored == 0 + + +class TestMarkReferencesMissingOutsidePrefixes: + def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + valid_dir = tmp_path / "valid" + valid_dir.mkdir() + invalid_dir = tmp_path / "invalid" + invalid_dir.mkdir() + + valid_path = str(valid_dir / "file.bin") + invalid_path = str(invalid_dir / "file.bin") + + _make_reference(session, asset, valid_path, name="valid") + _make_reference(session, asset, invalid_path, name="invalid") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)]) + session.commit() + + assert marked == 1 + all_refs = session.query(AssetReference).all() + assert len(all_refs) == 2 + + valid_ref = next(r for r in all_refs if r.file_path == valid_path) + invalid_ref = next(r for r in all_refs if r.file_path == invalid_path) + assert valid_ref.is_missing is False + assert invalid_ref.is_missing is True + + def test_empty_prefixes_marks_nothing(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, []) + + assert marked == 0 + + +class TestGetUnreferencedUnhashedAssetIds: + def test_returns_unreferenced_unhashed_assets(self, session: Session): + # Unhashed asset (hash=None) with no references (no file_path) + no_refs = _make_asset(session, hash_val=None) + # Unhashed asset with active reference (not unreferenced) + with_active_ref = _make_asset(session, hash_val=None) + _make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref") + # Unhashed asset with only missing reference (should be unreferenced) + with_missing_ref = _make_asset(session, hash_val=None) + missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref") + missing_ref.is_missing = True + # Regular asset (hash not None) - should not be returned + _make_asset(session, hash_val="blake3:regular") + session.commit() + + unreferenced = get_unreferenced_unhashed_asset_ids(session) + + assert no_refs.id in unreferenced + assert with_missing_ref.id in unreferenced + assert with_active_ref.id not in unreferenced + + +class TestDeleteAssetsByIds: + def test_deletes_assets_and_references(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_assets_by_ids(session, [asset.id]) + session.commit() + + assert deleted == 1 + assert session.query(Asset).count() == 0 + assert session.query(AssetReference).count() == 0 + + def test_empty_list_deletes_nothing(self, session: Session): + _make_asset(session, "hash1") + session.commit() + + deleted = delete_assets_by_ids(session, []) + + assert deleted == 0 + assert session.query(Asset).count() == 1 + + +class TestGetReferencesForPrefixes: + def test_returns_references_matching_prefix(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + dir1 = tmp_path / "dir1" + dir1.mkdir() + dir2 = tmp_path / "dir2" + dir2.mkdir() + + path1 = str(dir1 / "file.bin") + path2 = str(dir2 / "file.bin") + + _make_reference(session, asset, path1, name="file1", mtime_ns=100) + _make_reference(session, asset, path2, name="file2", mtime_ns=200) + session.commit() + + rows = get_references_for_prefixes(session, [str(dir1)]) + + assert len(rows) == 1 + assert rows[0].file_path == path1 + + def test_empty_prefixes_returns_empty(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/some/path.bin") + session.commit() + + rows = get_references_for_prefixes(session, []) + + assert rows == [] + + +class TestBulkSetNeedsVerify: + def test_sets_needs_verify_flag(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False) + ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False) + session.commit() + + updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True) + session.commit() + + assert updated == 2 + session.refresh(ref1) + session.refresh(ref2) + assert ref1.needs_verify is True + assert ref2.needs_verify is True + + def test_empty_list_updates_nothing(self, session: Session): + updated = bulk_update_needs_verify(session, [], True) + assert updated == 0 + + +class TestDeleteReferencesByIds: + def test_deletes_references_by_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, "/path1.bin") + _make_reference(session, asset, "/path2.bin") + session.commit() + + deleted = delete_references_by_ids(session, [ref1.id]) + session.commit() + + assert deleted == 1 + assert session.query(AssetReference).count() == 1 + + def test_empty_list_deletes_nothing(self, session: Session): + deleted = delete_references_by_ids(session, []) + assert deleted == 0 + + +class TestDeleteOrphanedSeedAsset: + @pytest.mark.parametrize( + "create_asset,expected_deleted,expected_count", + [ + (True, True, 0), # Existing asset gets deleted + (False, False, 0), # Nonexistent returns False + ], + ids=["deletes_existing", "nonexistent_returns_false"], + ) + def test_delete_orphaned_seed_asset( + self, session: Session, create_asset, expected_deleted, expected_count + ): + asset_id = "nonexistent-id" + if create_asset: + asset = _make_asset(session, hash_val=None) + asset_id = asset.id + _make_reference(session, asset, "/test/path.bin", name="test") + session.commit() + + deleted = delete_orphaned_seed_asset(session, asset_id) + if create_asset: + session.commit() + + assert deleted is expected_deleted + assert session.query(Asset).count() == expected_count + + +class TestBulkInsertReferencesIgnoreConflicts: + def test_inserts_multiple_references(self, session: Session): + asset = _make_asset(session, "hash1") + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/bulk1.bin", + "name": "bulk1", + "mtime_ns": 100, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/bulk2.bin", + "name": "bulk2", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "/existing.bin", mtime_ns=100) + session.commit() + + now = get_utc_now() + rows = [ + { + "asset_id": asset.id, + "file_path": "/existing.bin", + "name": "existing", + "mtime_ns": 999, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + { + "asset_id": asset.id, + "file_path": "/new.bin", + "name": "new", + "mtime_ns": 200, + "created_at": now, + "updated_at": now, + "last_access_time": now, + }, + ] + bulk_insert_references_ignore_conflicts(session, rows) + session.commit() + + assert session.query(AssetReference).count() == 2 + existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one() + assert existing.mtime_ns == 100 # Original value preserved + + def test_empty_list_is_noop(self, session: Session): + bulk_insert_references_ignore_conflicts(session, []) + assert session.query(AssetReference).count() == 0 + + +class TestGetReferencesByPathsAndAssetIds: + def test_returns_matching_paths(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + _make_reference(session, asset2, "/path2.bin") + session.commit() + + path_to_asset = { + "/path1.bin": asset1.id, + "/path2.bin": asset2.id, + } + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == {"/path1.bin", "/path2.bin"} + + def test_excludes_non_matching_asset_ids(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + + _make_reference(session, asset1, "/path1.bin") + session.commit() + + # Path exists but with different asset_id + path_to_asset = {"/path1.bin": asset2.id} + winners = get_references_by_paths_and_asset_ids(session, path_to_asset) + + assert winners == set() + + def test_empty_dict_returns_empty(self, session: Session): + winners = get_references_by_paths_and_asset_ids(session, {}) + assert winners == set() diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py new file mode 100644 index 000000000..6a545e819 --- /dev/null +++ b/tests-unit/assets_test/queries/test_metadata.py @@ -0,0 +1,184 @@ +"""Tests for metadata filtering logic in asset_reference queries.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta +from app.assets.database.queries import list_references_page +from app.assets.database.queries.asset_reference import convert_metadata_to_rows +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str, + metadata: dict | None = None, +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id="", + name=name, + asset_id=asset.id, + user_metadata=metadata, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + + if metadata: + for key, val in metadata.items(): + for row in convert_metadata_to_rows(key, val): + meta_row = AssetReferenceMeta( + asset_reference_id=ref.id, + key=row["key"], + ordinal=row.get("ordinal", 0), + val_str=row.get("val_str"), + val_num=row.get("val_num"), + val_bool=row.get("val_bool"), + val_json=row.get("val_json"), + ) + session.add(meta_row) + session.flush() + + return ref + + +class TestMetadataFilterByType: + """Table-driven tests for metadata filtering by different value types.""" + + @pytest.mark.parametrize( + "match_meta,nomatch_meta,filter_key,filter_val", + [ + # String matching + ({"category": "models"}, {"category": "images"}, "category", "models"), + # Integer matching + ({"epoch": 5}, {"epoch": 10}, "epoch", 5), + # Float matching + ({"score": 0.95}, {"score": 0.5}, "score", 0.95), + # Boolean True matching + ({"enabled": True}, {"enabled": False}, "enabled", True), + # Boolean False matching + ({"enabled": False}, {"enabled": True}, "enabled", False), + ], + ids=["string", "int", "float", "bool_true", "bool_false"], + ) + def test_filter_matches_correct_value( + self, session: Session, match_meta, nomatch_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", match_meta) + _make_reference(session, asset, "nomatch", nomatch_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 1 + assert refs[0].name == "match" + + @pytest.mark.parametrize( + "stored_meta,filter_key,filter_val", + [ + # String no match + ({"category": "models"}, "category", "other"), + # Int no match + ({"epoch": 5}, "epoch", 99), + # Float no match + ({"score": 0.5}, "score", 0.99), + ], + ids=["string_no_match", "int_no_match", "float_no_match"], + ) + def test_filter_returns_empty_when_no_match( + self, session: Session, stored_meta, filter_key, filter_val + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "item", stored_meta) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={filter_key: filter_val} + ) + assert total == 0 + + +class TestMetadataFilterNull: + """Tests for null/missing key filtering.""" + + @pytest.mark.parametrize( + "match_name,match_meta,nomatch_name,nomatch_meta,filter_key", + [ + # Null matches missing key + ("missing_key", {}, "has_key", {"optional": "value"}, "optional"), + # Null matches explicit null + ("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"), + ], + ids=["missing_key", "explicit_null"], + ) + def test_null_filter_matches( + self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key + ): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, match_name, match_meta) + _make_reference(session, asset, nomatch_name, nomatch_meta) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={filter_key: None}) + assert total == 1 + assert refs[0].name == match_name + + +class TestMetadataFilterList: + """Tests for list-based (OR) filtering.""" + + def test_filter_by_list_matches_any(self, session: Session): + """List values should match ANY of the values (OR).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "cat_a", {"category": "a"}) + _make_reference(session, asset, "cat_b", {"category": "b"}) + _make_reference(session, asset, "cat_c", {"category": "c"}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={"category": ["a", "b"]}) + assert total == 2 + names = {r.name for r in refs} + assert names == {"cat_a", "cat_b"} + + +class TestMetadataFilterMultipleKeys: + """Tests for multiple filter keys (AND semantics).""" + + def test_multiple_keys_must_all_match(self, session: Session): + """Multiple keys should ALL match (AND).""" + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "match", {"type": "model", "version": 2}) + _make_reference(session, asset, "wrong_type", {"type": "config", "version": 2}) + _make_reference(session, asset, "wrong_version", {"type": "model", "version": 1}) + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={"type": "model", "version": 2} + ) + assert total == 1 + assert refs[0].name == "match" + + +class TestMetadataFilterEmptyDict: + """Tests for empty filter behavior.""" + + def test_empty_filter_returns_all(self, session: Session): + asset = _make_asset(session, "hash1") + _make_reference(session, asset, "a", {"key": "val"}) + _make_reference(session, asset, "b", {}) + session.commit() + + refs, _, total = list_references_page(session, metadata_filter={}) + assert total == 2 diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py new file mode 100644 index 000000000..4ed99aa37 --- /dev/null +++ b/tests-unit/assets_test/queries/test_tags.py @@ -0,0 +1,366 @@ +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag +from app.assets.database.queries import ( + ensure_tags_exist, + get_reference_tags, + set_reference_tags, + add_tags_to_reference, + remove_tags_from_reference, + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, + list_tags_with_usage, + bulk_insert_tags_and_meta, +) +from app.assets.helpers import get_utc_now + + +def _make_asset(session: Session, hash_val: str | None = None) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestEnsureTagsExist: + def test_creates_new_tags(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_is_idempotent(self, session: Session): + ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"], tag_type="user") + session.commit() + + assert session.query(Tag).count() == 1 + + def test_normalizes_tags(self, session: Session): + ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"]) + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_empty_list_is_noop(self, session: Session): + ensure_tags_exist(session, []) + session.commit() + assert session.query(Tag).count() == 0 + + def test_tag_type_is_set(self, session: Session): + ensure_tags_exist(session, ["system-tag"], tag_type="system") + session.commit() + + tag = session.query(Tag).filter_by(name="system-tag").one() + assert tag.tag_type == "system" + + +class TestGetReferenceTags: + def test_returns_empty_for_no_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + tags = get_reference_tags(session, reference_id=ref.id) + assert tags == [] + + def test_returns_tags_for_reference(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + ensure_tags_exist(session, ["tag1", "tag2"]) + session.add_all([ + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag1", origin="manual", added_at=get_utc_now()), + AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag2", origin="manual", added_at=get_utc_now()), + ]) + session.flush() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"tag1", "tag2"} + + +class TestSetReferenceTags: + def test_adds_new_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.added) == {"a", "b"} + assert result.removed == [] + assert set(result.total) == {"a", "b"} + + def test_removes_old_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["a"]) + session.commit() + + assert result.added == [] + assert set(result.removed) == {"b", "c"} + assert result.total == ["a"] + + def test_replaces_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) + result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"]) + session.commit() + + assert result.added == ["c"] + assert result.removed == ["a"] + assert set(result.total) == {"b", "c"} + + +class TestAddTagsToReference: + def test_adds_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert set(result.added) == {"x", "y"} + assert result.already_present == [] + + def test_reports_already_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["x"]) + result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) + session.commit() + + assert result.added == ["y"] + assert result.already_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + add_tags_to_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestRemoveTagsFromReference: + def test_removes_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"]) + session.commit() + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + + add_tags_to_reference(session, reference_id=ref.id, tags=["a"]) + result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"]) + session.commit() + + assert result.removed == ["a"] + assert result.not_present == ["x"] + + def test_raises_for_missing_reference(self, session: Session): + with pytest.raises(ValueError, match="not found"): + remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"]) + + +class TestMissingTagFunctions: + def test_add_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" in tags + + def test_add_missing_tag_is_idempotent(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="missing").all() + assert len(links) == 1 + + def test_remove_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + add_missing_tag_for_asset_id(session, asset_id=asset.id) + + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert "missing" not in tags + + +class TestListTagsWithUsage: + def test_returns_tags_with_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session) + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_exclude_zero_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session, include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags_with_usage(session, order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_owner_visibility(self, session: Session): + ensure_tags_exist(session, ["shared-tag", "owner-tag"]) + + asset = _make_asset(session, "hash1") + shared_ref = _make_reference(session, asset, name="shared", owner_id="") + owner_ref = _make_reference(session, asset, name="owned", owner_id="user1") + + add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"]) + add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"]) + session.commit() + + # Empty owner sees only shared + rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 0 + + # User1 sees both + rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 1 + + +class TestBulkInsertTagsAndMeta: + def test_inserts_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now}, + {"asset_reference_id": ref.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + tags = get_reference_tags(session, reference_id=ref.id) + assert set(tags) == {"bulk-tag1", "bulk-tag2"} + + def test_inserts_meta(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + session.commit() + + meta_rows = [ + { + "asset_reference_id": ref.id, + "key": "meta-key", + "ordinal": 0, + "val_str": "meta-value", + "val_num": None, + "val_bool": None, + "val_json": None, + }, + ] + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows) + session.commit() + + meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all() + assert len(meta) == 1 + assert meta[0].key == "meta-key" + assert meta[0].val_str == "meta-value" + + def test_ignores_conflicts(self, session: Session): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing-tag"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing-tag"]) + session.commit() + + now = get_utc_now() + tag_rows = [ + {"asset_reference_id": ref.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now}, + ] + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[]) + session.commit() + + # Should still have only one tag link + links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="existing-tag").all() + assert len(links) == 1 + # Origin should be original, not overwritten + assert links[0].origin == "manual" + + def test_empty_lists_is_noop(self, session: Session): + bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[]) + assert session.query(AssetReferenceTag).count() == 0 + assert session.query(AssetReferenceMeta).count() == 0 diff --git a/tests-unit/assets_test/services/__init__.py b/tests-unit/assets_test/services/__init__.py new file mode 100644 index 000000000..d0213422e --- /dev/null +++ b/tests-unit/assets_test/services/__init__.py @@ -0,0 +1 @@ +# Service layer tests diff --git a/tests-unit/assets_test/services/conftest.py b/tests-unit/assets_test/services/conftest.py new file mode 100644 index 000000000..31c763d48 --- /dev/null +++ b/tests-unit/assets_test/services/conftest.py @@ -0,0 +1,54 @@ +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(): + """Override parent autouse fixture - service unit tests don't need server cleanup.""" + yield + + +@pytest.fixture +def db_engine(): + """In-memory SQLite engine for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + """Session fixture for tests that need direct DB access.""" + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def mock_create_session(db_engine): + """Patch create_session to use our in-memory database.""" + from contextlib import contextmanager + from sqlalchemy.orm import Session as SASession + + @contextmanager + def _create_session(): + with SASession(db_engine) as sess: + yield sess + + with patch("app.assets.services.ingest.create_session", _create_session), \ + patch("app.assets.services.asset_management.create_session", _create_session), \ + patch("app.assets.services.tagging.create_session", _create_session): + yield _create_session + + +@pytest.fixture +def temp_dir(): + """Temporary directory for file operations.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) diff --git a/tests-unit/assets_test/services/test_asset_management.py b/tests-unit/assets_test/services/test_asset_management.py new file mode 100644 index 000000000..101ef7292 --- /dev/null +++ b/tests-unit/assets_test/services/test_asset_management.py @@ -0,0 +1,268 @@ +"""Tests for asset_management services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import ( + get_asset_detail, + update_asset_metadata, + delete_asset_reference, + set_asset_preview, +) + + +def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestGetAssetDetail: + def test_returns_none_for_nonexistent(self, mock_create_session): + result = get_asset_detail(reference_id="nonexistent") + assert result is None + + def test_returns_asset_with_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"]) + session.commit() + + result = get_asset_detail(reference_id=ref.id) + + assert result is not None + assert result.ref.id == ref.id + assert result.asset.hash == asset.hash + assert set(result.tags) == {"alpha", "beta"} + + def test_respects_owner_visibility(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + # Wrong owner cannot see + result = get_asset_detail(reference_id=ref.id, owner_id="user2") + assert result is None + + # Correct owner can see + result = get_asset_detail(reference_id=ref.id, owner_id="user1") + assert result is not None + + +class TestUpdateAssetMetadata: + def test_updates_name(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, name="old_name.bin") + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + name="new_name.bin", + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.name == "new_name.bin" + + def test_updates_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["old"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["old"]) + session.commit() + + result = update_asset_metadata( + reference_id=ref.id, + tags=["new1", "new2"], + ) + + assert set(result.tags) == {"new1", "new2"} + assert "old" not in result.tags + + def test_updates_user_metadata(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + update_asset_metadata( + reference_id=ref_id, + user_metadata={"key": "value", "num": 42}, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.user_metadata["key"] == "value" + assert updated_ref.user_metadata["num"] == 42 + + def test_raises_for_nonexistent(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + update_asset_metadata(reference_id="nonexistent", name="fail") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + update_asset_metadata( + reference_id=ref.id, + name="new", + owner_id="user2", + ) + + +class TestDeleteAssetReference: + def test_soft_deletes_reference(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=False, + ) + + assert result is True + # Row still exists but is marked as soft-deleted + session.expire_all() + row = session.get(AssetReference, ref_id) + assert row is not None + assert row.deleted_at is not None + + def test_returns_false_for_nonexistent(self, mock_create_session): + result = delete_asset_reference( + reference_id="nonexistent", + owner_id="", + ) + assert result is False + + def test_returns_false_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + ref_id = ref.id + session.commit() + + result = delete_asset_reference( + reference_id=ref_id, + owner_id="user2", + ) + + assert result is False + assert session.get(AssetReference, ref_id) is not None + + def test_keeps_asset_if_other_references_exist(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref1 = _make_reference(session, asset, name="ref1") + _make_reference(session, asset, name="ref2") # Second ref keeps asset alive + asset_id = asset.id + session.commit() + + delete_asset_reference( + reference_id=ref1.id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Asset should still exist + assert session.get(Asset, asset_id) is not None + + def test_deletes_orphaned_asset(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + asset_id = asset.id + ref_id = ref.id + session.commit() + + delete_asset_reference( + reference_id=ref_id, + owner_id="", + delete_content_if_orphan=True, + ) + + # Both ref and asset should be gone + assert session.get(AssetReference, ref_id) is None + assert session.get(Asset, asset_id) is None + + +class TestSetAssetPreview: + def test_sets_preview(self, mock_create_session, session: Session): + asset = _make_asset(session, hash_val="blake3:main") + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref_id = ref.id + preview_id = preview_asset.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=preview_id, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id == preview_id + + def test_clears_preview(self, mock_create_session, session: Session): + asset = _make_asset(session) + preview_asset = _make_asset(session, hash_val="blake3:preview") + ref = _make_reference(session, asset) + ref.preview_id = preview_asset.id + ref_id = ref.id + session.commit() + + set_asset_preview( + reference_id=ref_id, + preview_asset_id=None, + ) + + # Verify by re-fetching from DB + session.expire_all() + updated_ref = session.get(AssetReference, ref_id) + assert updated_ref.preview_id is None + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + set_asset_preview(reference_id="nonexistent") + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + set_asset_preview( + reference_id=ref.id, + preview_asset_id=None, + owner_id="user2", + ) diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py new file mode 100644 index 000000000..26e22a01d --- /dev/null +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -0,0 +1,137 @@ +"""Tests for bulk ingest services.""" + +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets + + +class TestBatchInsertSeedAssets: + def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path): + """Verify mime_type is stored in the Asset table for model files.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "Test Model", + "tags": ["models"], + "fname": "model.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + # Verify Asset has mime_type populated + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type == "application/safetensors" + + def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path): + """Verify mime_type is None when not provided in spec.""" + file_path = temp_dir / "unknown.bin" + file_path.write_bytes(b"binary data") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 11, + "mtime_ns": 1234567890000000000, + "info_name": "Unknown File", + "tags": [], + "fname": "unknown.bin", + "metadata": None, + "hash": None, + "mime_type": None, + } + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].mime_type is None + + def test_various_model_mime_types(self, session: Session, temp_dir: Path): + """Verify various model file types get correct mime_type.""" + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + specs: list[SeedAssetSpec] = [] + for filename, mime_type in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + specs.append( + { + "abs_path": str(file_path), + "size_bytes": 7, + "mtime_ns": 1234567890000000000, + "info_name": filename, + "tags": [], + "fname": filename, + "metadata": None, + "hash": None, + "mime_type": mime_type, + } + ) + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == len(test_cases) + + for filename, expected_mime in test_cases: + ref = session.query(AssetReference).filter_by(name=filename).first() + assert ref is not None + asset = session.query(Asset).filter_by(id=ref.asset_id).first() + assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}" + + +class TestMetadataExtraction: + def test_extracts_mime_type_for_model_files(self, temp_dir: Path): + """Verify metadata extraction returns correct mime_type for model files.""" + from app.assets.services.metadata_extract import extract_file_metadata + + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"fake safetensors content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == "application/safetensors" + + def test_mime_type_for_various_model_formats(self, temp_dir: Path): + """Verify various model file types get correct mime_type from metadata.""" + from app.assets.services.metadata_extract import extract_file_metadata + + test_cases = [ + ("model.safetensors", "application/safetensors"), + ("model.sft", "application/safetensors"), + ("model.pt", "application/pytorch"), + ("model.pth", "application/pytorch"), + ("model.ckpt", "application/pickle"), + ("model.pkl", "application/pickle"), + ("model.gguf", "application/gguf"), + ] + + for filename, expected_mime in test_cases: + file_path = temp_dir / filename + file_path.write_bytes(b"content") + + meta = extract_file_metadata(str(file_path)) + + assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}" diff --git a/tests-unit/assets_test/services/test_enrich.py b/tests-unit/assets_test/services/test_enrich.py new file mode 100644 index 000000000..2bd79a01a --- /dev/null +++ b/tests-unit/assets_test/services/test_enrich.py @@ -0,0 +1,207 @@ +"""Tests for asset enrichment (mime_type and hash population).""" +from pathlib import Path + +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.scanner import ( + ENRICHMENT_HASHED, + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + enrich_asset, +) + + +def _create_stub_asset( + session: Session, + file_path: str, + asset_id: str = "test-asset-id", + reference_id: str = "test-ref-id", + name: str | None = None, +) -> tuple[Asset, AssetReference]: + """Create a stub asset with reference for testing enrichment.""" + asset = Asset( + id=asset_id, + hash=None, + size_bytes=100, + mime_type=None, + ) + session.add(asset) + session.flush() + + ref = AssetReference( + id=reference_id, + asset_id=asset_id, + name=name or f"test-asset-{asset_id}", + owner_id="system", + file_path=file_path, + mtime_ns=1234567890000000000, + enrichment_level=ENRICHMENT_STUB, + ) + session.add(ref) + session.flush() + + return asset, ref + + +class TestEnrichAsset: + def test_extracts_mime_type_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify mime_type is written to the Asset table during enrichment.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 100) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-1", "ref-1" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=False, + ) + + assert new_level == ENRICHMENT_METADATA + + session.expire_all() + updated_asset = session.get(Asset, "asset-1") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + + def test_computes_hash_and_updates_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify hash is written to the Asset table during enrichment.""" + file_path = temp_dir / "data.bin" + file_path.write_bytes(b"test content for hashing") + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-2", "ref-2" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_HASHED + + session.expire_all() + updated_asset = session.get(Asset, "asset-2") + assert updated_asset is not None + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_enrichment_updates_both_mime_and_hash( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify both mime_type and hash are set when full enrichment runs.""" + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"\x00" * 50) + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-3", "ref-3" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + updated_asset = session.get(Asset, "asset-3") + assert updated_asset is not None + assert updated_asset.mime_type == "application/safetensors" + assert updated_asset.hash is not None + assert updated_asset.hash.startswith("blake3:") + + def test_missing_file_returns_stub_level( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify missing files don't cause errors and return STUB level.""" + file_path = temp_dir / "nonexistent.bin" + + asset, ref = _create_stub_asset( + session, str(file_path), "asset-4", "ref-4" + ) + session.commit() + + new_level = enrich_asset( + session, + file_path=str(file_path), + reference_id=ref.id, + asset_id=asset.id, + extract_metadata=True, + compute_hash=True, + ) + + assert new_level == ENRICHMENT_STUB + + session.expire_all() + updated_asset = session.get(Asset, "asset-4") + assert updated_asset.mime_type is None + assert updated_asset.hash is None + + def test_duplicate_hash_merges_into_existing_asset( + self, db_engine, temp_dir: Path, session: Session + ): + """Verify duplicate files merge into existing asset instead of failing.""" + file_path_1 = temp_dir / "file1.bin" + file_path_2 = temp_dir / "file2.bin" + content = b"identical content" + file_path_1.write_bytes(content) + file_path_2.write_bytes(content) + + asset1, ref1 = _create_stub_asset( + session, str(file_path_1), "asset-dup-1", "ref-dup-1" + ) + asset2, ref2 = _create_stub_asset( + session, str(file_path_2), "asset-dup-2", "ref-dup-2" + ) + session.commit() + + enrich_asset( + session, + file_path=str(file_path_1), + reference_id=ref1.id, + asset_id=asset1.id, + extract_metadata=True, + compute_hash=True, + ) + + enrich_asset( + session, + file_path=str(file_path_2), + reference_id=ref2.id, + asset_id=asset2.id, + extract_metadata=True, + compute_hash=True, + ) + + session.expire_all() + + updated_asset1 = session.get(Asset, "asset-dup-1") + assert updated_asset1 is not None + assert updated_asset1.hash is not None + + updated_asset2 = session.get(Asset, "asset-dup-2") + assert updated_asset2 is None + + updated_ref2 = session.get(AssetReference, "ref-dup-2") + assert updated_ref2 is not None + assert updated_ref2.asset_id == "asset-dup-1" diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py new file mode 100644 index 000000000..367bc7721 --- /dev/null +++ b/tests-unit/assets_test/services/test_ingest.py @@ -0,0 +1,229 @@ +"""Tests for ingest services.""" +from pathlib import Path + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference, Tag +from app.assets.database.queries import get_reference_tags +from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset + + +class TestIngestFileFromPath: + def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "test_file.bin" + file_path.write_bytes(b"test content") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:abc123", + size_bytes=12, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + assert result.asset_created is True + assert result.ref_created is True + assert result.reference_id is not None + + # Verify DB state + assets = session.query(Asset).all() + assert len(assets) == 1 + assert assets[0].hash == "blake3:abc123" + + refs = session.query(AssetReference).all() + assert len(refs) == 1 + assert refs[0].file_path == str(file_path) + + def test_creates_reference_when_name_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "model.safetensors" + file_path.write_bytes(b"model data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:def456", + size_bytes=10, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + info_name="My Model", + owner_id="user1", + ) + + assert result.asset_created is True + assert result.reference_id is not None + + ref = session.query(AssetReference).first() + assert ref is not None + assert ref.name == "My Model" + assert ref.owner_id == "user1" + + def test_creates_tags_when_provided(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "tagged.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:ghi789", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Tagged Asset", + tags=["models", "checkpoints"], + ) + + assert result.reference_id is not None + + # Verify tags were created and linked + tags = session.query(Tag).all() + tag_names = {t.name for t in tags} + assert "models" in tag_names + assert "checkpoints" in tag_names + + ref_tags = get_reference_tags(session, reference_id=result.reference_id) + assert set(ref_tags) == {"models", "checkpoints"} + + def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "dup.bin" + file_path.write_bytes(b"content") + + # First ingest + r1 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000000, + ) + assert r1.asset_created is True + + # Second ingest with same hash - should update, not create + r2 = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:repeat", + size_bytes=7, + mtime_ns=1234567890000000001, # different mtime + ) + assert r2.asset_created is False + assert r2.ref_created is False + assert r2.ref_updated is True + + # Still only one asset + assets = session.query(Asset).all() + assert len(assets) == 1 + + def test_validates_preview_id(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "with_preview.bin" + file_path.write_bytes(b"data") + + # Create a preview asset first + preview_asset = Asset(hash="blake3:preview", size_bytes=100) + session.add(preview_asset) + session.commit() + preview_id = preview_asset.id + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:main", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="With Preview", + preview_id=preview_id, + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id == preview_id + + def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session): + file_path = temp_dir / "bad_preview.bin" + file_path.write_bytes(b"data") + + result = _ingest_file_from_path( + abs_path=str(file_path), + asset_hash="blake3:badpreview", + size_bytes=4, + mtime_ns=1234567890000000000, + info_name="Bad Preview", + preview_id="nonexistent-uuid", + ) + + assert result.reference_id is not None + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.preview_id is None + + +class TestRegisterExistingAsset: + def test_creates_reference_for_existing_asset(self, mock_create_session, session: Session): + # Create existing asset + asset = Asset(hash="blake3:existing", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:existing", + name="Registered Asset", + user_metadata={"key": "value"}, + tags=["models"], + ) + + assert result.created is True + assert "models" in result.tags + + # Verify by re-fetching from DB + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Registered Asset").all() + assert len(refs) == 1 + + def test_creates_new_reference_even_with_same_name(self, mock_create_session, session: Session): + # Create asset and reference + asset = Asset(hash="blake3:withref", size_bytes=512) + session.add(asset) + session.flush() + + from app.assets.helpers import get_utc_now + ref = AssetReference( + owner_id="", + name="Existing Ref", + asset_id=asset.id, + created_at=get_utc_now(), + updated_at=get_utc_now(), + last_access_time=get_utc_now(), + ) + session.add(ref) + session.flush() + ref_id = ref.id + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:withref", + name="Existing Ref", + owner_id="", + ) + + # Multiple files with same name are allowed + assert result.created is True + + # Verify two AssetReferences exist for this name + session.expire_all() + refs = session.query(AssetReference).filter_by(name="Existing Ref").all() + assert len(refs) == 2 + assert ref_id in [r.id for r in refs] + + def test_raises_for_nonexistent_hash(self, mock_create_session): + with pytest.raises(ValueError, match="No asset with hash"): + _register_existing_asset( + asset_hash="blake3:doesnotexist", + name="Fail", + ) + + def test_applies_tags_to_new_reference(self, mock_create_session, session: Session): + asset = Asset(hash="blake3:tagged", size_bytes=256) + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:tagged", + name="Tagged Ref", + tags=["alpha", "beta"], + ) + + assert result.created is True + assert set(result.tags) == {"alpha", "beta"} diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py new file mode 100644 index 000000000..ab69e5dc1 --- /dev/null +++ b/tests-unit/assets_test/services/test_tagging.py @@ -0,0 +1,197 @@ +"""Tests for tagging services.""" +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services import apply_tags, remove_tags, list_tags + + +def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestApplyTags: + def test_adds_new_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["alpha", "beta"], + ) + + assert set(result.added) == {"alpha", "beta"} + assert result.already_present == [] + assert set(result.total_tags) == {"alpha", "beta"} + + def test_reports_already_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["existing"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["existing"]) + session.commit() + + result = apply_tags( + reference_id=ref.id, + tags=["existing", "new"], + ) + + assert result.added == ["new"] + assert result.already_present == ["existing"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + apply_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + apply_tags( + reference_id=ref.id, + tags=["new"], + owner_id="user2", + ) + + +class TestRemoveTags: + def test_removes_tags(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["a", "b", "c"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["a", "b"], + ) + + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] + + def test_reports_not_present(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset) + ensure_tags_exist(session, ["present"]) + add_tags_to_reference(session, reference_id=ref.id, tags=["present"]) + session.commit() + + result = remove_tags( + reference_id=ref.id, + tags=["present", "absent"], + ) + + assert result.removed == ["present"] + assert result.not_present == ["absent"] + + def test_raises_for_nonexistent_ref(self, mock_create_session): + with pytest.raises(ValueError, match="not found"): + remove_tags(reference_id="nonexistent", tags=["x"]) + + def test_raises_for_wrong_owner(self, mock_create_session, session: Session): + asset = _make_asset(session) + ref = _make_reference(session, asset, owner_id="user1") + session.commit() + + with pytest.raises(PermissionError, match="not owner"): + remove_tags( + reference_id=ref.id, + tags=["x"], + owner_id="user2", + ) + + +class TestListTags: + def test_returns_tags_with_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags() + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_excludes_zero_counts(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + asset = _make_asset(session) + ref = _make_reference(session, asset) + add_tags_to_reference(session, reference_id=ref.id, tags=["used"]) + session.commit() + + rows, total = list_tags(include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not in tag_names + + def test_prefix_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, _ = list_tags(prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags(order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_pagination(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a", "b", "c", "d", "e"]) + session.commit() + + rows, total = list_tags(limit=2, offset=1, order="name_asc") + + assert total == 5 + assert len(rows) == 2 + names = [name for name, _, _ in rows] + assert names == ["b", "c"] + + def test_clamps_limit(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["a"]) + session.commit() + + # Service should clamp limit to max 1000 + rows, _ = list_tags(limit=2000) + assert len(rows) <= 1000 diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py index 78fa7b404..47dc130cb 100644 --- a/tests-unit/assets_test/test_assets_missing_sync.py +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index d2b69f475..07310223e 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_create_from_hash_success( @@ -24,11 +24,11 @@ def test_create_from_hash_success( assert b1["created_new"] is False aid = b1["id"] - # Calling again with the same name should return the same AssetInfo id + # Calling again with the same name creates a new AssetReference (duplicates allowed) r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) b2 = r2.json() assert r2.status_code == 201, b2 - assert b2["id"] == aid + assert b2["id"] != aid # new reference, not the same one def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): @@ -42,8 +42,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert "user_metadata" in detail assert "filename" in detail["user_metadata"] - # DELETE - rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + # DELETE (hard delete to also remove underlying asset and file) + rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) assert rd.status_code == 204 # GET again -> 404 @@ -53,6 +53,35 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert body["error"]["code"] == "ASSET_NOT_FOUND" +def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + asset_hash = seeded_asset["asset_hash"] + + # Soft-delete (default, no delete_content param) + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET by reference ID -> 404 (soft-deleted references are hidden) + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + assert rg.status_code == 404 + + # Asset identity is preserved (underlying content still exists) + rh = http.head(f"{api_base}/api/assets/hash/{asset_hash}", timeout=120) + assert rh.status_code == 200 + + # Soft-deleted reference should not appear in listings + rl = http.get( + f"{api_base}/api/assets", + params={"include_tags": "unit-tests", "limit": "500"}, + timeout=120, + ) + ids = [a["id"] for a in rl.json().get("assets", [])] + assert aid not in ids + + # Clean up: hard-delete the soft-deleted reference and orphaned asset + http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + + def test_delete_upon_reference_count( http: requests.Session, api_base: str, seeded_asset: dict ): @@ -70,21 +99,32 @@ def test_delete_upon_reference_count( assert copy["asset_hash"] == src_hash assert copy["created_new"] is False - # Delete original reference -> asset identity must remain + # Soft-delete original reference (default) -> asset identity must remain aid1 = seeded_asset["id"] rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) assert rd1.status_code == 204 rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh1.status_code == 200 # identity still present + assert rh1.status_code == 200 # identity still present (second ref exists) - # Delete the last reference with default semantics -> identity and cached files removed + # Soft-delete the last reference -> asset identity preserved (no hard delete) aid2 = copy["id"] rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) assert rd2.status_code == 204 rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh2.status_code == 404 # orphan content removed + assert rh2.status_code == 200 # asset identity preserved (soft delete) + + # Re-associate via from-hash, then hard-delete -> orphan content removed + r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + assert r3.status_code == 201, r3.json() + aid3 = r3.json()["id"] + + rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120) + assert rd3.status_code == 204 + + rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh3.status_code == 404 # orphan content removed def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): @@ -126,42 +166,52 @@ def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api assert body == b"" -def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str): - bogus = str(uuid.uuid4()) - r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120) +@pytest.mark.parametrize( + "method,endpoint_template,payload,expected_status,error_code", + [ + # Delete nonexistent asset + ("delete", "/api/assets/{uuid}", None, 404, "ASSET_NOT_FOUND"), + # Bad hash algorithm in from-hash + ( + "post", + "/api/assets/from-hash", + {"hash": "sha256:" + "0" * 64, "name": "x.bin", "tags": ["models", "checkpoints", "unit-tests"]}, + 400, + "INVALID_BODY", + ), + # Get with bad UUID format + ("get", "/api/assets/not-a-uuid", None, 404, None), + # Get content with bad UUID format + ("get", "/api/assets/not-a-uuid/content", None, 404, None), + ], + ids=["delete_nonexistent", "bad_hash_algorithm", "get_bad_uuid", "content_bad_uuid"], +) +def test_error_responses( + http: requests.Session, api_base: str, method, endpoint_template, payload, expected_status, error_code +): + # Replace {uuid} placeholder with a random UUID for delete test + endpoint = endpoint_template.replace("{uuid}", str(uuid.uuid4())) + url = f"{api_base}{endpoint}" + + if method == "get": + r = http.get(url, timeout=120) + elif method == "post": + r = http.post(url, json=payload, timeout=120) + elif method == "delete": + r = http.delete(url, timeout=120) + + assert r.status_code == expected_status + if error_code: + body = r.json() + assert body["error"]["code"] == error_code + + +def test_create_from_hash_invalid_json(http: requests.Session, api_base: str): + """Invalid JSON body requires special handling (data= instead of json=).""" + r = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) body = r.json() - assert r.status_code == 404 - assert body["error"]["code"] == "ASSET_NOT_FOUND" - - -def test_create_from_hash_invalids(http: requests.Session, api_base: str): - # Bad hash algorithm - bad = { - "hash": "sha256:" + "0" * 64, - "name": "x.bin", - "tags": ["models", "checkpoints", "unit-tests"], - } - r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_BODY" - - # Invalid JSON body - r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_JSON" - - -def test_get_update_download_bad_ids(http: requests.Session, api_base: str): - # All endpoints should be not found, as we UUID regex directly in the route definition. - bad_id = "not-a-uuid" - - r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120) - assert r1.status_code == 404 - - r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120) - assert r3.status_code == 404 + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_JSON" def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict): diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index cdebf9082..672ba9728 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -6,7 +6,7 @@ from typing import Optional import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): @@ -117,7 +117,7 @@ def test_download_missing_file_returns_404( assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. - dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) dr.content diff --git a/tests-unit/assets_test/test_file_utils.py b/tests-unit/assets_test/test_file_utils.py new file mode 100644 index 000000000..e3591d49b --- /dev/null +++ b/tests-unit/assets_test/test_file_utils.py @@ -0,0 +1,121 @@ +import os +import sys + +import pytest + +from app.assets.services.file_utils import is_visible, list_files_recursively + + +class TestIsVisible: + def test_visible_file(self): + assert is_visible("file.txt") is True + + def test_hidden_file(self): + assert is_visible(".hidden") is False + + def test_hidden_directory(self): + assert is_visible(".git") is False + + def test_visible_directory(self): + assert is_visible("src") is True + + def test_dotdot_is_hidden(self): + assert is_visible("..") is False + + def test_dot_is_hidden(self): + assert is_visible(".") is False + + +class TestListFilesRecursively: + def test_skips_hidden_files(self, tmp_path): + (tmp_path / "visible.txt").write_text("a") + (tmp_path / ".hidden").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert result[0].endswith("visible.txt") + + def test_skips_hidden_directories(self, tmp_path): + hidden_dir = tmp_path / ".hidden_dir" + hidden_dir.mkdir() + (hidden_dir / "file.txt").write_text("a") + + visible_dir = tmp_path / "visible_dir" + visible_dir.mkdir() + (visible_dir / "file.txt").write_text("b") + + result = list_files_recursively(str(tmp_path)) + + assert len(result) == 1 + assert "visible_dir" in result[0] + assert ".hidden_dir" not in result[0] + + def test_empty_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path)) + assert result == [] + + def test_nonexistent_directory(self, tmp_path): + result = list_files_recursively(str(tmp_path / "nonexistent")) + assert result == [] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_directories(self, tmp_path): + target = tmp_path / "real_dir" + target.mkdir() + (target / "model.safetensors").write_text("data") + + root = tmp_path / "root" + root.mkdir() + (root / "link").symlink_to(target) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("model.safetensors") + assert "link" in result[0] + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_follows_symlinked_files(self, tmp_path): + real_file = tmp_path / "real.txt" + real_file.write_text("content") + + root = tmp_path / "root" + root.mkdir() + (root / "link.txt").symlink_to(real_file) + + result = list_files_recursively(str(root)) + + assert len(result) == 1 + assert result[0].endswith("link.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_circular_symlinks_do_not_loop(self, tmp_path): + dir_a = tmp_path / "a" + dir_a.mkdir() + (dir_a / "file.txt").write_text("a") + # a/b -> a (circular) + (dir_a / "b").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + + assert len(result) == 1 + assert result[0].endswith("file.txt") + + @pytest.mark.skipif(sys.platform == "win32", reason="symlinks need privileges on Windows") + def test_mutual_circular_symlinks(self, tmp_path): + dir_a = tmp_path / "a" + dir_b = tmp_path / "b" + dir_a.mkdir() + dir_b.mkdir() + (dir_a / "file_a.txt").write_text("a") + (dir_b / "file_b.txt").write_text("b") + # a/link_b -> b and b/link_a -> a + (dir_a / "link_b").symlink_to(dir_b) + (dir_b / "link_a").symlink_to(dir_a) + + result = list_files_recursively(str(dir_a)) + basenames = sorted(os.path.basename(p) for p in result) + + assert "file_a.txt" in basenames + assert "file_b.txt" in basenames diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py index 82e109832..dcb7a73ca 100644 --- a/tests-unit/assets_test/test_list_filter.py +++ b/tests-unit/assets_test/test_list_filter.py @@ -1,6 +1,7 @@ import time import uuid +import pytest import requests @@ -283,30 +284,21 @@ def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asse assert b2["has_more"] is False -def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): - r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" - - -def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): - # limit too small - r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) - b1 = r1.json() - assert r1.status_code == 400 - assert b1["error"]["code"] == "INVALID_QUERY" - - # bad metadata JSON - r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) - b2 = r2.json() - assert r2.status_code == 400 - assert b2["error"]["code"] == "INVALID_QUERY" +@pytest.mark.parametrize( + "params,error_code", + [ + ({"offset": "-1"}, "INVALID_QUERY"), + ({"limit": "abc"}, "INVALID_QUERY"), + ({"limit": "0"}, "INVALID_QUERY"), + ({"metadata_filter": "{not json"}, "INVALID_QUERY"), + ], + ids=["negative_offset", "non_int_limit", "zero_limit", "invalid_metadata_json"], +) +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str, params, error_code): + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == error_code def test_list_assets_name_contains_literal_underscore( diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py index f602e5a77..1fbd4d4e2 100644 --- a/tests-unit/assets_test/test_prune_orphaned_assets.py +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest import requests -from conftest import get_asset_filename, trigger_sync_seed_assets +from helpers import get_asset_filename, trigger_sync_seed_assets @pytest.fixture diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py new file mode 100644 index 000000000..94cc255bc --- /dev/null +++ b/tests-unit/assets_test/test_sync_references.py @@ -0,0 +1,482 @@ +"""Tests for sync_references_with_filesystem in scanner.py.""" + +import os +import tempfile +from datetime import datetime +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceTag, + Base, + Tag, +) +from app.assets.database.queries.asset_reference import ( + bulk_insert_references_ignore_conflicts, + get_references_for_prefixes, + get_unenriched_references, + restore_references_by_paths, +) +from app.assets.scanner import sync_references_with_filesystem +from app.assets.services.file_utils import get_mtime_ns + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str: + """Create a file and return its absolute path (no symlink resolution).""" + p = temp_dir / name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return os.path.abspath(str(p)) + + +def _stat_mtime_ns(path: str) -> int: + return get_mtime_ns(os.stat(path, follow_symlinks=True)) + + +def _make_asset( + session: Session, + asset_id: str, + file_path: str, + ref_id: str, + *, + asset_hash: str | None = None, + size_bytes: int = 100, + mtime_ns: int | None = None, + needs_verify: bool = False, + is_missing: bool = False, +) -> tuple[Asset, AssetReference]: + """Insert an Asset + AssetReference and flush.""" + asset = session.get(Asset, asset_id) + if asset is None: + asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes) + session.add(asset) + session.flush() + + ref = AssetReference( + id=ref_id, + asset_id=asset_id, + name=f"test-{ref_id}", + owner_id="system", + file_path=file_path, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + is_missing=is_missing, + ) + session.add(ref) + session.flush() + return asset, ref + + +def _ensure_missing_tag(session: Session): + """Ensure the 'missing' tag exists.""" + if not session.get(Tag, "missing"): + session.add(Tag(name="missing", tag_type="system")) + session.flush() + + +class _VerifyCase: + def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify): + self.id = id + self.stat_unchanged = stat_unchanged + self.needs_verify_before = needs_verify_before + self.expect_needs_verify = expect_needs_verify + + +VERIFY_CASES = [ + _VerifyCase( + id="unchanged_clears_verify", + stat_unchanged=True, + needs_verify_before=True, + expect_needs_verify=False, + ), + _VerifyCase( + id="unchanged_keeps_clear", + stat_unchanged=True, + needs_verify_before=False, + expect_needs_verify=False, + ), + _VerifyCase( + id="changed_sets_verify", + stat_unchanged=False, + needs_verify_before=False, + expect_needs_verify=True, + ), + _VerifyCase( + id="changed_keeps_verify", + stat_unchanged=False, + needs_verify_before=True, + expect_needs_verify=True, + ), +] + + +@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id) +def test_needs_verify_toggling(session, temp_dir, case): + """needs_verify is set/cleared based on mtime+size match.""" + fp = _create_file(temp_dir, "model.bin") + real_mtime = _stat_mtime_ns(fp) + + mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1 + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", + mtime_ns=mtime_for_db, + needs_verify=case.needs_verify_before, + ) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.needs_verify is case.expect_needs_verify + + +class _MissingCase: + def __init__(self, id, file_exists, expect_is_missing): + self.id = id + self.file_exists = file_exists + self.expect_is_missing = expect_is_missing + + +MISSING_CASES = [ + _MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False), + _MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True), +] + + +@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id) +def test_is_missing_flag(session, temp_dir, case): + """is_missing reflects whether the file exists on disk.""" + if case.file_exists: + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + else: + fp = str(temp_dir / "gone.bin") + mtime = 999 + + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is case.expect_is_missing + + +def test_seed_asset_all_missing_deletes_asset(session, temp_dir): + """Seed asset with all refs missing gets deleted entirely.""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + assert session.get(Asset, "seed1") is None + assert session.get(AssetReference, "r1") is None + + +def test_seed_asset_some_exist_returns_survivors(session, temp_dir): + """Seed asset with at least one existing ref survives and is returned.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert session.get(Asset, "seed1") is not None + assert os.path.abspath(fp) in survivors + + +def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir): + """Hashed asset with one stat-unchanged ref deletes missing refs.""" + fp_ok = _create_file(temp_dir, "good.bin") + fp_gone = str(temp_dir / "gone.bin") + mtime = _stat_mtime_ns(fp_ok) + + _make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime) + # Second ref on same asset, file missing + ref_gone = AssetReference( + id="r_gone", asset_id="h1", name="gone", + owner_id="system", file_path=fp_gone, mtime_ns=999, + ) + session.add(ref_gone) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r_ok") is not None + assert session.get(AssetReference, "r_gone") is None + + +def test_hashed_asset_all_missing_keeps_refs(session, temp_dir): + """Hashed asset with all refs missing keeps refs (no pruning).""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r1") is not None + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True + + +def test_missing_tag_added_when_all_refs_gone(session, temp_dir): + """Missing tag is added to hashed asset when all refs are missing.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is not None + + +def test_missing_tag_removed_when_ref_ok(session, temp_dir): + """Missing tag is removed from hashed asset when a ref is stat-unchanged.""" + _ensure_missing_tag(session) + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime) + # Pre-add a stale missing tag + session.add(AssetReferenceTag( + asset_reference_id="r1", tag_name="missing", origin="automatic", + )) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None + + +def test_missing_tags_not_touched_when_flag_false(session, temp_dir): + """Missing tags are not modified when update_missing_tags=False.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=False, + ) + session.commit() + + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None # tag was never added + + +def test_returns_none_when_collect_false(session, temp_dir): + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=False, + ) + + assert result is None + + +def test_returns_empty_set_for_no_prefixes(session): + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + + assert result == set() + + +def test_no_references_is_noop(session, temp_dir): + """No crash and no side effects when there are no references.""" + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert survivors == set() + + +# --------------------------------------------------------------------------- +# Soft-delete persistence across scanner operations +# --------------------------------------------------------------------------- + +def _soft_delete_ref(session: Session, ref_id: str) -> None: + """Mark a reference as soft-deleted (mimics the API DELETE behaviour).""" + ref = session.get(AssetReference, ref_id) + ref.deleted_at = datetime(2025, 1, 1) + session.flush() + + +def test_soft_deleted_ref_excluded_from_get_references_for_prefixes(session, temp_dir): + """get_references_for_prefixes skips soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_references_for_prefixes(session, [str(temp_dir)], include_missing=True) + assert len(rows) == 0 + + +def test_sync_does_not_resurrect_soft_deleted_ref(session, temp_dir): + """Scanner sync leaves soft-deleted refs untouched even when file exists on disk.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.deleted_at is not None, "soft-deleted ref must stay deleted after sync" + + +def test_bulk_insert_does_not_overwrite_soft_deleted_ref(session, temp_dir): + """bulk_insert_references_ignore_conflicts cannot replace a soft-deleted row.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + now = datetime.now(tz=None) + bulk_insert_references_ignore_conflicts(session, [ + { + "id": "r_new", + "asset_id": "a1", + "file_path": fp, + "name": "model.bin", + "owner_id": "", + "mtime_ns": mtime, + "preview_id": None, + "user_metadata": None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ]) + session.commit() + + session.expire_all() + # Original row is still the soft-deleted one + ref = session.get(AssetReference, "r1") + assert ref is not None + assert ref.deleted_at is not None + # The new row was not inserted (conflict on file_path) + assert session.get(AssetReference, "r_new") is None + + +def test_restore_references_by_paths_skips_soft_deleted(session, temp_dir): + """restore_references_by_paths does not clear is_missing on soft-deleted refs.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", mtime_ns=mtime, is_missing=True, + ) + _soft_delete_ref(session, "r1") + session.commit() + + restored = restore_references_by_paths(session, [fp]) + session.commit() + + assert restored == 0 + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True, "is_missing must not be cleared on soft-deleted ref" + assert ref.deleted_at is not None + + +def test_get_unenriched_references_excludes_soft_deleted(session, temp_dir): + """Enrichment queries do not pick up soft-deleted references.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + _soft_delete_ref(session, "r1") + session.commit() + + rows = get_unenriched_references(session, [str(temp_dir)], max_level=2) + assert len(rows) == 0 + + +def test_sync_ignores_soft_deleted_seed_asset(session, temp_dir): + """Soft-deleted seed ref is not garbage-collected even when file is missing.""" + fp = str(temp_dir / "gone.bin") # file does not exist + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + _soft_delete_ref(session, "r1") + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + # Asset and ref must still exist — scanner did not see the soft-deleted row + assert session.get(Asset, "seed1") is not None + assert session.get(AssetReference, "r1") is not None diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags_api.py similarity index 98% rename from tests-unit/assets_test/test_tags.py rename to tests-unit/assets_test/test_tags_api.py index 6b1047802..595bf29c6 100644 --- a/tests-unit/assets_test/test_tags.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, used_names = [t["name"] for t in body2["tags"]] assert custom_tag in used_names - # Delete the asset so the tag usage drops to zero - rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + # Hard-delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120) assert rd.status_code == 204 # Now the custom tag must not be returned when include_zero=false diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index 137d7391a..d68e5b5d7 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -18,25 +18,25 @@ def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, ma assert r1.status_code == 201, a1 assert a1["created_new"] is True - # Second upload with the same data and name should return created_new == False and the same asset + # Second upload with the same data and name creates a new AssetReference (duplicates allowed) + # Returns 200 because Asset already exists, but a new AssetReference is created files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) a2 = r2.json() - assert r2.status_code == 200, a2 - assert a2["created_new"] is False + assert r2.status_code in (200, 201), a2 assert a2["asset_hash"] == a1["asset_hash"] - assert a2["id"] == a1["id"] # old reference + assert a2["id"] != a1["id"] # new reference with same content - # Third upload with the same data but new name should return created_new == False and the new AssetReference + # Third upload with the same data but different name also creates new AssetReference files = {"file": (name, data, "application/octet-stream")} form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)} - r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) - a3 = r2.json() - assert r2.status_code == 200, a3 - assert a3["created_new"] is False + r3 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a3 = r3.json() + assert r3.status_code in (200, 201), a3 assert a3["asset_hash"] == a1["asset_hash"] - assert a3["id"] != a1["id"] # old reference + assert a3["id"] != a1["id"] + assert a3["id"] != a2["id"] def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): @@ -116,7 +116,7 @@ def test_concurrent_upload_identical_bytes_different_names( ): """ Two concurrent uploads of identical bytes but different names. - Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. + Expect a single Asset (same hash), two AssetReference rows, and exactly one created_new=True. """ scope = f"concupload-{uuid.uuid4().hex[:6]}" name1, name2 = "cu_a.bin", "cu_b.bin" diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 2355b8000..3a6790ee0 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,4 +2,3 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client -blake3 diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py new file mode 100644 index 000000000..db3795e48 --- /dev/null +++ b/tests-unit/seeder_test/test_seeder.py @@ -0,0 +1,900 @@ +"""Unit tests for the _AssetSeeder background scanning class.""" + +import threading +from unittest.mock import patch + +import pytest + +from app.assets.database.queries.asset_reference import UnenrichedReferenceRow +from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State + + +@pytest.fixture +def fresh_seeder(): + """Create a fresh _AssetSeeder instance for testing.""" + seeder = _AssetSeeder() + yield seeder + seeder.shutdown(timeout=1.0) + + +@pytest.fixture +def mock_dependencies(): + """Mock all external dependencies for isolated testing.""" + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + yield + + +class TestSeederStateTransitions: + """Test state machine transitions.""" + + def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder): + assert fresh_seeder.get_status().state == State.IDLE + + def test_start_transitions_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + started = fresh_seeder.start(roots=("models",)) + assert started is True + assert reached.wait(timeout=2.0) + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_start_while_running_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + second_start = fresh_seeder.start(roots=("models",)) + assert second_start is False + + barrier.set() + + def test_cancel_transitions_to_cancelling( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + cancelled = fresh_seeder.cancel() + assert cancelled is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + cancelled = fresh_seeder.cancel() + assert cancelled is False + + def test_state_returns_to_idle_after_completion( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + +class TestSeederWait: + """Test wait() behavior.""" + + def test_wait_blocks_until_complete( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=5.0) + assert completed is True + assert fresh_seeder.get_status().state == State.IDLE + + def test_wait_returns_false_on_timeout( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=10.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + completed = fresh_seeder.wait(timeout=0.1) + assert completed is False + + barrier.set() + + def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder): + completed = fresh_seeder.wait(timeout=1.0) + assert completed is True + + +class TestSeederProgress: + """Test progress tracking.""" + + def test_get_status_returns_progress_during_scan( + self, fresh_seeder: _AssetSeeder + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_build(*args, **kwargs): + reached.set() + barrier.wait(timeout=5.0) + return ([], set(), 0) + + paths = ["/path/file1.safetensors", "/path/file2.safetensors"] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch("app.assets.seeder.build_asset_specs", side_effect=slow_build), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + status = fresh_seeder.get_status() + assert status.state == State.RUNNING + assert status.progress is not None + assert status.progress.total == 2 + + barrier.set() + + def test_progress_callback_is_invoked( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + progress_updates: list[Progress] = [] + + def callback(p: Progress): + progress_updates.append(p) + + with patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=[f"/path/file{i}.safetensors" for i in range(10)], + ): + fresh_seeder.start(roots=("models",), progress_callback=callback) + fresh_seeder.wait(timeout=5.0) + + assert len(progress_updates) > 0 + + +class TestSeederCancellation: + """Test cancellation behavior.""" + + def test_scan_commits_partial_progress_on_cancellation( + self, fresh_seeder: _AssetSeeder + ): + insert_count = 0 + barrier = threading.Event() + first_insert_done = threading.Event() + + def slow_insert(specs, tags): + nonlocal insert_count + insert_count += 1 + if insert_count == 1: + first_insert_done.set() + if insert_count >= 2: + barrier.wait(timeout=5.0) + return len(specs) + + paths = [f"/path/file{i}.safetensors" for i in range(1500)] + specs = [ + { + "abs_path": p, + "size_bytes": 100, + "mtime_ns": 0, + "info_name": f"file{i}", + "tags": [], + "fname": f"file{i}", + "metadata": None, + "hash": None, + "mime_type": None, + } + for i, p in enumerate(paths) + ] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), + patch( + "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0) + ), + patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + assert first_insert_done.wait(timeout=2.0) + + fresh_seeder.cancel() + barrier.set() + fresh_seeder.wait(timeout=5.0) + + assert 1 <= insert_count < 3 # 1500 paths / 500 batch = 3; cancel stopped early + + +class TestSeederErrorHandling: + """Test error handling behavior.""" + + def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch( + "app.assets.seeder.collect_paths_for_roots", + return_value=["/path/file.safetensors"], + ), + patch( + "app.assets.seeder.build_asset_specs", + return_value=( + [ + { + "abs_path": "/path/file.safetensors", + "size_bytes": 100, + "mtime_ns": 0, + "info_name": "file", + "tags": [], + "fname": "file", + "metadata": None, + "hash": None, + "mime_type": None, + } + ], + set(), + 0, + ), + ), + patch( + "app.assets.seeder.insert_asset_specs", + side_effect=Exception("DB connection failed"), + ), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "DB connection failed" in status.errors[0] + + def test_dependencies_unavailable_captured_in_errors( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert len(status.errors) > 0 + assert "dependencies" in status.errors[0].lower() + + def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.sync_root_safely", + side_effect=RuntimeError("Unexpected crash"), + ), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + status = fresh_seeder.get_status() + assert status.state == State.IDLE + assert len(status.errors) > 0 + + +class TestSeederThreadSafety: + """Test thread safety of concurrent operations.""" + + def test_concurrent_start_calls_spawn_only_one_thread( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + + def slow_collect(*args): + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + results = [] + + def try_start(): + results.append(fresh_seeder.start(roots=("models",))) + + threads = [threading.Thread(target=try_start) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + barrier.set() + + assert sum(results) == 1 + + def test_get_status_safe_during_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + statuses = [] + for _ in range(100): + statuses.append(fresh_seeder.get_status()) + + barrier.set() + + assert all( + s.state in (State.RUNNING, State.IDLE, State.CANCELLING) + for s in statuses + ) + + +class TestSeederMarkMissing: + """Test mark_missing_outside_prefixes behavior.""" + + def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder): + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch( + "app.assets.seeder.get_all_known_prefixes", + return_value=["/models", "/input", "/output"], + ), + patch( + "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5 + ) as mock_mark, + ): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 5 + mock_mark.assert_called_once_with(["/models", "/input", "/output"]) + + def test_mark_missing_raises_when_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + with pytest.raises(ScanInProgressError): + fresh_seeder.mark_missing_outside_prefixes() + + barrier.set() + + def test_mark_missing_returns_zero_when_dependencies_unavailable( + self, fresh_seeder: _AssetSeeder + ): + with patch("app.assets.seeder.dependencies_available", return_value=False): + result = fresh_seeder.mark_missing_outside_prefixes() + assert result == 0 + + def test_prune_first_flag_triggers_mark_missing_before_scan( + self, fresh_seeder: _AssetSeeder + ): + call_order = [] + + def track_mark(prefixes): + call_order.append("mark_missing") + return 3 + + def track_sync(root): + call_order.append(f"sync_{root}") + return set() + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]), + patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark), + patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), prune_first=True) + fresh_seeder.wait(timeout=5.0) + + assert call_order[0] == "mark_missing" + assert "sync_models" in call_order + + +class TestSeederPhases: + """Test phased scanning behavior.""" + + def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_fast only runs the fast phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_fast(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 0 + + def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder): + """Verify start_enrich only runs the enrich phase.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start_enrich(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 0 + assert len(enrich_called) == 1 + + def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder): + """Verify full scan runs both fast and enrich phases.""" + fast_called = [] + enrich_called = [] + + def track_fast(*args, **kwargs): + fast_called.append(True) + return ([], set(), 0) + + def track_enrich(*args, **kwargs): + enrich_called.append(True) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.FULL) + fresh_seeder.wait(timeout=5.0) + + assert len(fast_called) == 1 + assert len(enrich_called) == 1 + + +class TestSeederPauseResume: + """Test pause/resume behavior.""" + + def test_pause_transitions_to_paused( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + paused = fresh_seeder.pause() + assert paused is True + assert fresh_seeder.get_status().state == State.PAUSED + + barrier.set() + + def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): + paused = fresh_seeder.pause() + assert paused is False + + def test_resume_returns_to_running( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + resumed = fresh_seeder.resume() + assert resumed is True + assert fresh_seeder.get_status().state == State.RUNNING + + barrier.set() + + def test_resume_when_not_paused_returns_false( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + resumed = fresh_seeder.resume() + assert resumed is False + + barrier.set() + + def test_cancel_while_paused_works( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached_checkpoint = threading.Event() + + def slow_collect(*args): + reached_checkpoint.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached_checkpoint.wait(timeout=2.0) + + fresh_seeder.pause() + assert fresh_seeder.get_status().state == State.PAUSED + + cancelled = fresh_seeder.cancel() + assert cancelled is True + + barrier.set() + fresh_seeder.wait(timeout=5.0) + assert fresh_seeder.get_status().state == State.IDLE + +class TestSeederStopRestart: + """Test stop and restart behavior.""" + + def test_stop_is_alias_for_cancel( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + stopped = fresh_seeder.stop() + assert stopped is True + assert fresh_seeder.get_status().state == State.CANCELLING + + barrier.set() + + def test_restart_cancels_and_starts_new_scan( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + barrier = threading.Event() + reached = threading.Event() + start_count = 0 + + def slow_collect(*args): + nonlocal start_count + start_count += 1 + if start_count == 1: + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",)) + assert reached.wait(timeout=2.0) + + barrier.set() + restarted = fresh_seeder.restart() + assert restarted is True + + fresh_seeder.wait(timeout=5.0) + assert start_count == 2 + + def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder): + """Verify restart uses previous params when not overridden.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("input", "output")) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart() + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("input", "output") + assert collected_roots[1] == ("input", "output") + + def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder): + """Verify restart can override previous params.""" + collected_roots = [] + + def track_collect(roots): + collected_roots.append(roots) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + fresh_seeder.start(roots=("models",)) + fresh_seeder.wait(timeout=5.0) + + fresh_seeder.restart(roots=("input",)) + fresh_seeder.wait(timeout=5.0) + + assert len(collected_roots) == 2 + assert collected_roots[0] == ("models",) + assert collected_roots[1] == ("input",) + + +def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: + return UnenrichedReferenceRow( + reference_id=ref_id, asset_id=asset_id, + file_path=f"/fake/{ref_id}.bin", enrichment_level=0, + ) + + +class TestEnrichPhaseDefensiveLogic: + """Test skip_ids filtering and consecutive_empty termination.""" + + def test_failed_refs_are_skipped_on_subsequent_batches( + self, fresh_seeder: _AssetSeeder, + ): + """References that fail enrichment are filtered out of future batches.""" + row_a = _make_row("r1") + row_b = _make_row("r2") + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + return [row_a, row_b] + return [] + + enriched_refs: list[list[str]] = [] + + def fake_enrich(rows, **kwargs): + ref_ids = [r.reference_id for r in rows] + enriched_refs.append(ref_ids) + # r1 always fails, r2 succeeds + failed = [r.reference_id for r in rows if r.reference_id == "r1"] + enriched = len(rows) - len(failed) + return enriched, failed + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # First batch: both refs attempted + assert "r1" in enriched_refs[0] + assert "r2" in enriched_refs[0] + # Second batch: r1 filtered out + assert "r1" not in enriched_refs[1] + assert "r2" in enriched_refs[1] + + def test_stops_after_consecutive_empty_batches( + self, fresh_seeder: _AssetSeeder, + ): + """Enrich phase terminates after 3 consecutive batches with zero progress.""" + row = _make_row("r1") + batch_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal batch_count + batch_count += 1 + # Always return the same row (simulating a permanently failing ref) + return [row] + + def fake_enrich(rows, **kwargs): + # Always fail — zero enriched, all failed + return 0, [r.reference_id for r in rows] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # Should stop after exactly 3 consecutive empty batches + # Batch 1: returns row, enrich fails → filtered out in batch 2+ + # But get_unenriched keeps returning it, filter removes it → empty → break + # Actually: batch 1 has row, fails. Batch 2 get_unenriched returns [row], + # skip_ids filters it → empty list → breaks via `if not unenriched: break` + # So it terminates in 2 calls to get_unenriched. + assert batch_count == 2 + + def test_consecutive_empty_counter_resets_on_success( + self, fresh_seeder: _AssetSeeder, + ): + """A successful batch resets the consecutive empty counter.""" + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 6: + return [_make_row(f"r{call_count}", f"a{call_count}")] + return [] + + def fake_enrich(rows, **kwargs): + ref_id = rows[0].reference_id + # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6 + if ref_id in ("r1", "r2", "r4", "r5"): + return 0, [ref_id] + return 1, [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich), + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH) + fresh_seeder.wait(timeout=5.0) + + # All 6 batches should run + 1 final call returning empty + assert call_count == 7 + status = fresh_seeder.get_status() + assert status.state == State.IDLE diff --git a/utils/mime_types.py b/utils/mime_types.py new file mode 100644 index 000000000..916e963c5 --- /dev/null +++ b/utils/mime_types.py @@ -0,0 +1,37 @@ +"""Centralized MIME type initialization. + +Call init_mime_types() once at startup to initialize the MIME type database +and register all custom types used across ComfyUI. +""" + +import mimetypes + +_initialized = False + + +def init_mime_types(): + """Initialize the MIME type database and register all custom types. + + Safe to call multiple times; only runs once. + """ + global _initialized + if _initialized: + return + _initialized = True + + mimetypes.init() + + # Web types (used by server.py for static file serving) + mimetypes.add_type('application/javascript; charset=utf-8', '.js') + mimetypes.add_type('image/webp', '.webp') + + # Model and data file types (used by asset scanning / metadata extraction) + mimetypes.add_type("application/safetensors", ".safetensors") + mimetypes.add_type("application/safetensors", ".sft") + mimetypes.add_type("application/pytorch", ".pt") + mimetypes.add_type("application/pytorch", ".pth") + mimetypes.add_type("application/pickle", ".ckpt") + mimetypes.add_type("application/pickle", ".pkl") + mimetypes.add_type("application/gguf", ".gguf") + mimetypes.add_type("application/yaml", ".yaml") + mimetypes.add_type("application/yaml", ".yml") From 7723f20bbe010a3ea4eac602f77b0ff496f123c4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:17:40 -0700 Subject: [PATCH 37/45] comfy-aimdo 0.2.9 (#12840) Comfy-aimdo 0.2.9 fixes a context issue where if a non-main thread does a spurious garbage collection, cudaFrees are attempted with bad context. Some new APIs for displaying aimdo stats in UI widgets are also added. These are purely additive getters that dont touch cuda APIs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9527135ec..b1db1cf24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.7 +comfy-aimdo>=0.2.9 requests simpleeval>=1.0.0 blake3 From e4b0bb8305a4069ef7ff8396bfc6057c736ab95b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:25:30 -0700 Subject: [PATCH 38/45] Import assets seeder later, print some package versions. (#12841) --- app/assets/services/hashing.py | 6 +++++- main.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py index 92aee6402..41d8b4615 100644 --- a/app/assets/services/hashing.py +++ b/app/assets/services/hashing.py @@ -3,8 +3,12 @@ import os from contextlib import contextmanager from dataclasses import dataclass from typing import IO, Any, Callable, Iterator +import logging -from blake3 import blake3 +try: + from blake3 import blake3 +except ModuleNotFoundError: + logging.warning("WARNING: blake3 package not installed") DEFAULT_CHUNK = 8 * 1024 * 1024 diff --git a/main.py b/main.py index a8fc1a28d..1977f9362 100644 --- a/main.py +++ b/main.py @@ -3,11 +3,11 @@ comfy.options.enable_args_parsing() import os import importlib.util +import importlib.metadata import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger -from app.assets.seeder import asset_seeder import itertools import utils.extra_config from utils.mime_types import init_mime_types @@ -182,6 +182,7 @@ if 'torch' in sys.modules: import comfy.utils +from app.assets.seeder import asset_seeder import execution import server @@ -451,6 +452,11 @@ if __name__ == "__main__": # Running directly, just start ComfyUI. logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) + for package in ("comfy-aimdo", "comfy-kitchen"): + try: + logging.info("{} version: {}".format(package, importlib.metadata.version(package))) + except: + pass if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") From 06f85e2c792c626f2cab3cb4f94cd30d43e9347b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:08:51 +0200 Subject: [PATCH 39/45] Fix text encoder lora loading for wrapped models (#12852) --- comfy/lora.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index f36ddb046..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False From 814dab9f4636df22a36cbbad21e35ac7609a0ef2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 10 Mar 2026 10:03:22 +0800 Subject: [PATCH 40/45] Update workflow templates to v0.9.18 (#12857) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b1db1cf24..bb58f8d01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.11 +comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch torchsde From 740d998c9cc821ca0a72b5b5d4b17aba1aec6b44 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:49:31 +0900 Subject: [PATCH 41/45] fix(manager): improve install guidance when comfyui-manager is not installed (#12810) --- main.py | 13 ++++++++++--- manager_requirements.txt | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 1977f9362..83a7244db 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ comfy.options.enable_args_parsing() import os import importlib.util +import shutil import importlib.metadata import folder_paths import time @@ -64,8 +65,15 @@ if __name__ == "__main__": def handle_comfyui_manager_unavailable(): - if not args.windows_standalone_build: - logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + uv_available = shutil.which("uv") is not None + + pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" + msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}" + if uv_available: + msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}" + msg += "\n" + logging.warning(msg) args.enable_manager = False @@ -173,7 +181,6 @@ execute_prestartup_script() # Main code import asyncio -import shutil import threading import gc diff --git a/manager_requirements.txt b/manager_requirements.txt index c420cc48e..6bcc3fb50 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b1 +comfyui_manager==4.1b2 \ No newline at end of file From c4fb0271cd7fbddb2381372b1f7c1206d1dd58fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:37:58 -0700 Subject: [PATCH 42/45] Add a way for nodes to add pre attn patches to flux model. (#12861) --- comfy/ldm/flux/layers.py | 15 ++++++++++++++- comfy/ldm/flux/math.py | 2 ++ comfy/ldm/flux/model.py | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8b3f500d7..e20d498f8 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module): del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) del txt_v, img_v + + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v if "attn1_output_patch" in transformer_patches: - extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] patch = transformer_patches["attn1_output_patch"] for p in patch: attn = p(attn, extra_options) @@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module): del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 5e764bb46..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4dcf7c5..00f12c031 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -170,7 +170,7 @@ class Flux(nn.Module): if "post_input" in patches: for p in patches["post_input"]: - out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) img = out["img"] txt = out["txt"] img_ids = out["img_ids"] From a912809c252f5a2d69c8ab4035fc262a578fdcee Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:50:10 -0700 Subject: [PATCH 43/45] model_detection: deep clone pre edited edited weights (#12862) Deep clone these weights as needed to avoid segfaulting when it tries to touch the original mmap. --- comfy/model_detection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6eace4628..35a6822e3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,4 +1,5 @@ import json +import comfy.memory_management import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): new[:old_weight.shape[0]] = old_weight old_weight = new + if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled: + old_weight = old_weight.clone() + w = old_weight.narrow(offset[0], offset[1], offset[2]) else: + if comfy.memory_management.aimdo_enabled: + weight = weight.clone() old_weight = weight w = weight w[:] = fun(weight) From 535c16ce6e3d2634d6eb2fd17ecccb8d497e26a0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Mar 2026 21:41:02 -0700 Subject: [PATCH 44/45] Widen OOM_EXCEPTION to AcceleratorError form (#12835) Pytorch only filters for OOMs in its own allocators however there are paths that can OOM on allocators made outside the pytorch allocators. These manifest as an AllocatorError as pytorch does not have universal error translation to its OOM type on exception. Handle it. A log I have for this also shows a double report of the error async, so call the async discarder to cleanup and make these OOMs look like OOMs. --- comfy/ldm/modules/attention.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 6 ++++-- comfy/ldm/modules/sub_quadratic_attention.py | 3 ++- comfy/model_management.py | 12 ++++++++++++ comfy/sd.py | 6 ++++-- comfy_extras/nodes_upscale_model.py | 3 ++- execution.py | 2 +- 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 10d051325..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 805592aa5..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -258,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -314,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/model_management.py b/comfy/model_management.py index 07bc8ad67..81550c790 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,6 +270,18 @@ try: except: OOM_EXCEPTION = Exception +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: diff --git a/comfy/sd.py b/comfy/sd.py index 888ef1e77..adcd67767 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -954,7 +954,8 @@ class VAE: if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1029,7 +1030,8 @@ class VAE: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 97b9e948d..db4f9d231 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode): pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) oom = False - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) tile //= 2 if tile < 128: raise e diff --git a/execution.py b/execution.py index 7ccdbf93e..a7791efed 100644 --- a/execution.py +++ b/execution.py @@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.error(traceback.format_exc()) tips = "" - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") From 8086468d2a1a5a6ed70fea3391e7fb9248ebc7da Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 10 Mar 2026 09:05:31 -0700 Subject: [PATCH 45/45] main: switch on faulthandler (#12868) When we get segfault bug reports we dont get much. Switch on pythons inbuilt tracer for segfault. --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index 83a7244db..8905fd09a 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ from app.logger import setup_logger import itertools import utils.extra_config from utils.mime_types import init_mime_types +import faulthandler import logging import sys from comfy_execution.progress import get_progress_state @@ -26,6 +27,8 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +faulthandler.enable(file=sys.stderr, all_threads=False) + import comfy_aimdo.control if enables_dynamic_vram():