test: add unit tests for CacheProvider API

- Add comprehensive tests for _canonicalize deterministic ordering
- Add tests for serialize_cache_key hash consistency
- Add tests for contains_nan utility
- Add tests for estimate_value_size
- Add tests for provider registry (register, unregister, clear)
- Move json import to top-level (fix inline import)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Deep Mehta
2026-01-24 14:20:49 +05:30
parent e17571d9be
commit 5e4bbca1ad
2 changed files with 364 additions and 7 deletions

View File

@@ -24,11 +24,12 @@ Example usage:
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple, List
from dataclasses import dataclass
import logging
import threading
import hashlib
import pickle
import json
import logging
import math
import pickle
import threading
logger = logging.getLogger(__name__)
@@ -210,8 +211,6 @@ def _canonicalize(obj: Any) -> Any:
which is critical for cross-pod cache key consistency. Frozensets in particular
have non-deterministic iteration order between Python sessions.
"""
import json
if isinstance(obj, frozenset):
# Sort frozenset items for deterministic ordering
return ("__frozenset__", sorted(
@@ -252,8 +251,6 @@ def serialize_cache_key(cache_key: Any) -> bytes:
affecting frozenset iteration order. This is critical for distributed caching
where different pods need to compute the same hash for identical inputs.
"""
import json
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))

View File

@@ -0,0 +1,360 @@
"""Tests for external cache provider API."""
import math
import pytest
from unittest.mock import MagicMock, patch
from typing import Optional
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider,
unregister_cache_provider,
get_cache_providers,
has_cache_providers,
clear_cache_providers,
serialize_cache_key,
contains_nan,
estimate_value_size,
_canonicalize,
)
class TestCanonicalize:
"""Test _canonicalize function for deterministic ordering."""
def test_frozenset_ordering_is_deterministic(self):
"""Frozensets should produce consistent canonical form regardless of iteration order."""
# Create two frozensets with same content
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_nested_frozenset_ordering(self):
"""Nested frozensets should also be deterministically ordered."""
inner1 = frozenset([1, 2, 3])
inner2 = frozenset([3, 2, 1])
fs1 = frozenset([("key", inner1)])
fs2 = frozenset([("key", inner2)])
result1 = _canonicalize(fs1)
result2 = _canonicalize(fs2)
assert result1 == result2
def test_dict_ordering(self):
"""Dicts should be sorted by key."""
d1 = {"z": 1, "a": 2, "m": 3}
d2 = {"a": 2, "m": 3, "z": 1}
result1 = _canonicalize(d1)
result2 = _canonicalize(d2)
assert result1 == result2
def test_tuple_preserved(self):
"""Tuples should be marked and preserved."""
t = (1, 2, 3)
result = _canonicalize(t)
assert result[0] == "__tuple__"
assert result[1] == [1, 2, 3]
def test_list_preserved(self):
"""Lists should be recursively canonicalized."""
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
result = _canonicalize(lst)
# First element should be dict with sorted keys
assert result[0] == {"a": 1, "b": 2}
# Second element should be canonicalized frozenset
assert result[1][0] == "__frozenset__"
def test_primitives_unchanged(self):
"""Primitive types should pass through unchanged."""
assert _canonicalize(42) == 42
assert _canonicalize(3.14) == 3.14
assert _canonicalize("hello") == "hello"
assert _canonicalize(True) is True
assert _canonicalize(None) is None
def test_bytes_converted(self):
"""Bytes should be converted to hex string."""
b = b"\x00\xff"
result = _canonicalize(b)
assert result[0] == "__bytes__"
assert result[1] == "00ff"
def test_set_ordering(self):
"""Sets should be sorted like frozensets."""
s1 = {3, 1, 2}
s2 = {1, 2, 3}
result1 = _canonicalize(s1)
result2 = _canonicalize(s2)
assert result1 == result2
assert result1[0] == "__set__"
class TestSerializeCacheKey:
"""Test serialize_cache_key for deterministic hashing."""
def test_same_content_same_hash(self):
"""Same content should produce same hash."""
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
hash1 = serialize_cache_key(key1)
hash2 = serialize_cache_key(key2)
assert hash1 == hash2
def test_different_content_different_hash(self):
"""Different content should produce different hash."""
key1 = frozenset([("node_1", "value_a")])
key2 = frozenset([("node_1", "value_b")])
hash1 = serialize_cache_key(key1)
hash2 = serialize_cache_key(key2)
assert hash1 != hash2
def test_returns_bytes(self):
"""Should return bytes (SHA256 digest)."""
key = frozenset([("test", 123)])
result = serialize_cache_key(key)
assert isinstance(result, bytes)
assert len(result) == 32 # SHA256 produces 32 bytes
def test_complex_nested_structure(self):
"""Complex nested structures should hash deterministically."""
key = frozenset([
("node_1", frozenset([
("input_a", ("tuple", "value")),
("input_b", {"nested": "dict"}),
])),
("node_2", frozenset([
("param", 42),
])),
])
# Hash twice to verify determinism
hash1 = serialize_cache_key(key)
hash2 = serialize_cache_key(key)
assert hash1 == hash2
class TestContainsNan:
"""Test contains_nan utility function."""
def test_nan_float_detected(self):
"""NaN floats should be detected."""
assert contains_nan(float('nan')) is True
def test_regular_float_not_nan(self):
"""Regular floats should not be detected as NaN."""
assert contains_nan(3.14) is False
assert contains_nan(0.0) is False
assert contains_nan(-1.5) is False
def test_infinity_not_nan(self):
"""Infinity is not NaN."""
assert contains_nan(float('inf')) is False
assert contains_nan(float('-inf')) is False
def test_nan_in_list(self):
"""NaN in list should be detected."""
assert contains_nan([1, 2, float('nan'), 4]) is True
assert contains_nan([1, 2, 3, 4]) is False
def test_nan_in_tuple(self):
"""NaN in tuple should be detected."""
assert contains_nan((1, float('nan'))) is True
assert contains_nan((1, 2, 3)) is False
def test_nan_in_frozenset(self):
"""NaN in frozenset should be detected."""
assert contains_nan(frozenset([1, float('nan')])) is True
assert contains_nan(frozenset([1, 2, 3])) is False
def test_nan_in_dict_value(self):
"""NaN in dict value should be detected."""
assert contains_nan({"key": float('nan')}) is True
assert contains_nan({"key": 42}) is False
def test_nan_in_nested_structure(self):
"""NaN in deeply nested structure should be detected."""
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
assert contains_nan(nested) is True
def test_non_numeric_types(self):
"""Non-numeric types should not be NaN."""
assert contains_nan("string") is False
assert contains_nan(None) is False
assert contains_nan(True) is False
class TestEstimateValueSize:
"""Test estimate_value_size utility function."""
def test_empty_outputs(self):
"""Empty outputs should have zero size."""
value = CacheValue(outputs=[])
assert estimate_value_size(value) == 0
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_tensor_size_estimation(self):
"""Tensor size should be estimated correctly."""
import torch
# 1000 float32 elements = 4000 bytes
tensor = torch.zeros(1000, dtype=torch.float32)
value = CacheValue(outputs=[[tensor]])
size = estimate_value_size(value)
assert size == 4000
@pytest.mark.skipif(
not _torch_available(),
reason="PyTorch not available"
)
def test_nested_tensor_in_dict(self):
"""Tensors nested in dicts should be counted."""
import torch
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
value = CacheValue(outputs=[[{"samples": tensor}]])
size = estimate_value_size(value)
assert size == 400
class TestProviderRegistry:
"""Test cache provider registration and retrieval."""
def setup_method(self):
"""Clear providers before each test."""
clear_cache_providers()
def teardown_method(self):
"""Clear providers after each test."""
clear_cache_providers()
def test_register_provider(self):
"""Provider should be registered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
assert has_cache_providers() is True
providers = get_cache_providers()
assert len(providers) == 1
assert providers[0] is provider
def test_unregister_provider(self):
"""Provider should be unregistered successfully."""
provider = MockCacheProvider()
register_cache_provider(provider)
unregister_cache_provider(provider)
assert has_cache_providers() is False
def test_multiple_providers(self):
"""Multiple providers can be registered."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
providers = get_cache_providers()
assert len(providers) == 2
def test_duplicate_registration_ignored(self):
"""Registering same provider twice should be ignored."""
provider = MockCacheProvider()
register_cache_provider(provider)
register_cache_provider(provider) # Should be ignored
providers = get_cache_providers()
assert len(providers) == 1
def test_clear_providers(self):
"""clear_cache_providers should remove all providers."""
provider1 = MockCacheProvider()
provider2 = MockCacheProvider()
register_cache_provider(provider1)
register_cache_provider(provider2)
clear_cache_providers()
assert has_cache_providers() is False
assert len(get_cache_providers()) == 0
class TestCacheContext:
"""Test CacheContext dataclass."""
def test_context_creation(self):
"""CacheContext should be created with all fields."""
context = CacheContext(
prompt_id="prompt-123",
node_id="node-456",
class_type="KSampler",
cache_key=frozenset([("test", "value")]),
cache_key_bytes=b"hash_bytes",
)
assert context.prompt_id == "prompt-123"
assert context.node_id == "node-456"
assert context.class_type == "KSampler"
assert context.cache_key == frozenset([("test", "value")])
assert context.cache_key_bytes == b"hash_bytes"
class TestCacheValue:
"""Test CacheValue dataclass."""
def test_value_creation(self):
"""CacheValue should be created with outputs."""
outputs = [[{"samples": "tensor_data"}]]
value = CacheValue(outputs=outputs)
assert value.outputs == outputs
class MockCacheProvider(CacheProvider):
"""Mock cache provider for testing."""
def __init__(self):
self.lookups = []
self.stores = []
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
self.lookups.append(context)
return None
def on_store(self, context: CacheContext, value: CacheValue) -> None:
self.stores.append((context, value))
def _torch_available() -> bool:
"""Check if PyTorch is available."""
try:
import torch
return True
except ImportError:
return False