diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 63edb9ad0..91669e583 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -230,8 +230,7 @@ class BasicCache: """Notify external providers of cache store (fire-and-forget).""" from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers, - CacheContext, CacheValue, - _serialize_cache_key, _contains_nan, _logger + CacheValue, _contains_nan, _logger ) # Fast exit conditions @@ -269,8 +268,7 @@ class BasicCache: """Check external providers for cached result.""" from comfy_execution.cache_provider import ( _has_cache_providers, _get_cache_providers, - CacheContext, CacheValue, - _contains_nan, _logger + CacheValue, _contains_nan, _logger ) if self._is_subcache: diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py index c7484e1a1..1d9ca095a 100644 --- a/tests-unit/execution_test/test_cache_provider.py +++ b/tests-unit/execution_test/test_cache_provider.py @@ -16,12 +16,12 @@ from comfy_execution.cache_provider import ( CacheValue, register_cache_provider, unregister_cache_provider, - get_cache_providers, - has_cache_providers, - clear_cache_providers, - serialize_cache_key, - contains_nan, - estimate_value_size, + _get_cache_providers, + _has_cache_providers, + _clear_cache_providers, + _serialize_cache_key, + _contains_nan, + _estimate_value_size, _canonicalize, ) @@ -110,15 +110,15 @@ class TestCanonicalize: class TestSerializeCacheKey: - """Test serialize_cache_key for deterministic hashing.""" + """Test _serialize_cache_key for deterministic hashing.""" def test_same_content_same_hash(self): """Same content should produce same hash.""" key1 = frozenset([("node_1", frozenset([("input", "value")]))]) key2 = frozenset([("node_1", frozenset([("input", "value")]))]) - hash1 = serialize_cache_key(key1) - hash2 = serialize_cache_key(key2) + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) assert hash1 == hash2 @@ -127,18 +127,18 @@ class TestSerializeCacheKey: key1 = frozenset([("node_1", "value_a")]) key2 = frozenset([("node_1", "value_b")]) - hash1 = serialize_cache_key(key1) - hash2 = serialize_cache_key(key2) + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) assert hash1 != hash2 - def test_returns_bytes(self): - """Should return bytes (SHA256 digest).""" + def test_returns_hex_string(self): + """Should return hex string (SHA256 hex digest).""" key = frozenset([("test", 123)]) - result = serialize_cache_key(key) + result = _serialize_cache_key(key) - assert isinstance(result, bytes) - assert len(result) == 32 # SHA256 produces 32 bytes + assert isinstance(result, str) + assert len(result) == 64 # SHA256 hex digest is 64 chars def test_complex_nested_structure(self): """Complex nested structures should hash deterministically.""" @@ -155,81 +155,81 @@ class TestSerializeCacheKey: ]) # Hash twice to verify determinism - hash1 = serialize_cache_key(key) - hash2 = serialize_cache_key(key) + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) assert hash1 == hash2 def test_dict_in_cache_key(self): - """Dicts passed directly to serialize_cache_key should work.""" + """Dicts passed directly to _serialize_cache_key should work.""" # This tests the _canonicalize function's ability to handle dicts key = {"node_1": {"input": "value"}, "node_2": 42} - hash1 = serialize_cache_key(key) - hash2 = serialize_cache_key(key) + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) assert hash1 == hash2 - assert isinstance(hash1, bytes) - assert len(hash1) == 32 + assert isinstance(hash1, str) + assert len(hash1) == 64 class TestContainsNan: - """Test contains_nan utility function.""" + """Test _contains_nan utility function.""" def test_nan_float_detected(self): """NaN floats should be detected.""" - assert contains_nan(float('nan')) is True + assert _contains_nan(float('nan')) is True def test_regular_float_not_nan(self): """Regular floats should not be detected as NaN.""" - assert contains_nan(3.14) is False - assert contains_nan(0.0) is False - assert contains_nan(-1.5) is False + assert _contains_nan(3.14) is False + assert _contains_nan(0.0) is False + assert _contains_nan(-1.5) is False def test_infinity_not_nan(self): """Infinity is not NaN.""" - assert contains_nan(float('inf')) is False - assert contains_nan(float('-inf')) is False + assert _contains_nan(float('inf')) is False + assert _contains_nan(float('-inf')) is False def test_nan_in_list(self): """NaN in list should be detected.""" - assert contains_nan([1, 2, float('nan'), 4]) is True - assert contains_nan([1, 2, 3, 4]) is False + assert _contains_nan([1, 2, float('nan'), 4]) is True + assert _contains_nan([1, 2, 3, 4]) is False def test_nan_in_tuple(self): """NaN in tuple should be detected.""" - assert contains_nan((1, float('nan'))) is True - assert contains_nan((1, 2, 3)) is False + assert _contains_nan((1, float('nan'))) is True + assert _contains_nan((1, 2, 3)) is False def test_nan_in_frozenset(self): """NaN in frozenset should be detected.""" - assert contains_nan(frozenset([1, float('nan')])) is True - assert contains_nan(frozenset([1, 2, 3])) is False + assert _contains_nan(frozenset([1, float('nan')])) is True + assert _contains_nan(frozenset([1, 2, 3])) is False def test_nan_in_dict_value(self): """NaN in dict value should be detected.""" - assert contains_nan({"key": float('nan')}) is True - assert contains_nan({"key": 42}) is False + assert _contains_nan({"key": float('nan')}) is True + assert _contains_nan({"key": 42}) is False def test_nan_in_nested_structure(self): """NaN in deeply nested structure should be detected.""" nested = {"level1": [{"level2": (1, 2, float('nan'))}]} - assert contains_nan(nested) is True + assert _contains_nan(nested) is True def test_non_numeric_types(self): """Non-numeric types should not be NaN.""" - assert contains_nan("string") is False - assert contains_nan(None) is False - assert contains_nan(True) is False + assert _contains_nan("string") is False + assert _contains_nan(None) is False + assert _contains_nan(True) is False class TestEstimateValueSize: - """Test estimate_value_size utility function.""" + """Test _estimate_value_size utility function.""" def test_empty_outputs(self): """Empty outputs should have zero size.""" value = CacheValue(outputs=[]) - assert estimate_value_size(value) == 0 + assert _estimate_value_size(value) == 0 @pytest.mark.skipif( not _torch_available(), @@ -243,7 +243,7 @@ class TestEstimateValueSize: tensor = torch.zeros(1000, dtype=torch.float32) value = CacheValue(outputs=[[tensor]]) - size = estimate_value_size(value) + size = _estimate_value_size(value) assert size == 4000 @pytest.mark.skipif( @@ -257,7 +257,7 @@ class TestEstimateValueSize: tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes value = CacheValue(outputs=[[{"samples": tensor}]]) - size = estimate_value_size(value) + size = _estimate_value_size(value) assert size == 400 @@ -266,19 +266,19 @@ class TestProviderRegistry: def setup_method(self): """Clear providers before each test.""" - clear_cache_providers() + _clear_cache_providers() def teardown_method(self): """Clear providers after each test.""" - clear_cache_providers() + _clear_cache_providers() def test_register_provider(self): """Provider should be registered successfully.""" provider = MockCacheProvider() register_cache_provider(provider) - assert has_cache_providers() is True - providers = get_cache_providers() + assert _has_cache_providers() is True + providers = _get_cache_providers() assert len(providers) == 1 assert providers[0] is provider @@ -288,7 +288,7 @@ class TestProviderRegistry: register_cache_provider(provider) unregister_cache_provider(provider) - assert has_cache_providers() is False + assert _has_cache_providers() is False def test_multiple_providers(self): """Multiple providers can be registered.""" @@ -298,7 +298,7 @@ class TestProviderRegistry: register_cache_provider(provider1) register_cache_provider(provider2) - providers = get_cache_providers() + providers = _get_cache_providers() assert len(providers) == 2 def test_duplicate_registration_ignored(self): @@ -308,20 +308,20 @@ class TestProviderRegistry: register_cache_provider(provider) register_cache_provider(provider) # Should be ignored - providers = get_cache_providers() + providers = _get_cache_providers() assert len(providers) == 1 def test_clear_providers(self): - """clear_cache_providers should remove all providers.""" + """_clear_cache_providers should remove all providers.""" provider1 = MockCacheProvider() provider2 = MockCacheProvider() register_cache_provider(provider1) register_cache_provider(provider2) - clear_cache_providers() + _clear_cache_providers() - assert has_cache_providers() is False - assert len(get_cache_providers()) == 0 + assert _has_cache_providers() is False + assert len(_get_cache_providers()) == 0 class TestCacheContext: @@ -333,15 +333,13 @@ class TestCacheContext: prompt_id="prompt-123", node_id="node-456", class_type="KSampler", - cache_key=frozenset([("test", "value")]), - cache_key_bytes=b"hash_bytes", + cache_key_hash="abcdef1234567890" * 4, ) assert context.prompt_id == "prompt-123" assert context.node_id == "node-456" assert context.class_type == "KSampler" - assert context.cache_key == frozenset([("test", "value")]) - assert context.cache_key_bytes == b"hash_bytes" + assert context.cache_key_hash == "abcdef1234567890" * 4 class TestCacheValue: @@ -362,9 +360,9 @@ class MockCacheProvider(CacheProvider): self.lookups = [] self.stores = [] - def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: self.lookups.append(context) return None - def on_store(self, context: CacheContext, value: CacheValue) -> None: + async def on_store(self, context: CacheContext, value: CacheValue) -> None: self.stores.append((context, value))