mirror of
https://github.com/unclecode/crawl4ai.git
synced 2026-06-11 00:08:01 +00:00
ContextVar.reset(token) requires the same Context that created the token. When Starlette's StreamingResponse consumes the async generator in a different Task, the Context changes and reset() raises ValueError. Replaced with set(False) which works across context boundaries. Safe because deep_crawl_active is never nested — the guard on line 21 prevents re-entry.
341 lines
12 KiB
Python
341 lines
12 KiB
Python
"""
|
|
Test Suite: Deep Crawl ContextVar Safety (Issue #1917)

Tests that DeepCrawlDecorator's ContextVar (deep_crawl_active) works correctly
when the async generator is consumed in a different asyncio context, as happens
with Starlette's StreamingResponse in the Docker API.

The bug: base_strategy.py used ContextVar.reset(token) in the generator's finally
block, but reset() requires the same Context that created the token. When Starlette
consumes the generator in a different Task, the Context changes -> ValueError.

The fix: use ContextVar.set(False) instead of reset(token), which works across
context boundaries.
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import MagicMock, AsyncMock
|
|
|
|
from crawl4ai.deep_crawling.base_strategy import DeepCrawlDecorator
|
|
|
|
|
|
# ============================================================================
# Helpers
# ============================================================================
|
|
|
|
def create_mock_result(url="https://example.com"):
    """Build a MagicMock that stands in for a successful crawl result.

    The returned mock exposes the given ``url``, ``success=True``, an empty
    ``metadata`` dict, and a ``links`` dict whose "internal" and "external"
    lists are both empty.
    """
    fake_result = MagicMock()
    fake_result.configure_mock(
        url=url,
        success=True,
        metadata={},
        links={"internal": [], "external": []},
    )
    return fake_result
|
|
|
|
|
|
def create_streaming_strategy(results):
    """Create a mock deep crawl strategy that streams results.

    The returned MagicMock has an async ``arun(start_url, crawler, config)``
    which ignores its arguments and returns a fresh async generator yielding
    every item of ``results`` in order.
    """

    async def _emit_all():
        # A new generator is produced per arun() call, so the strategy can be
        # invoked more than once.
        for item in results:
            yield item

    async def mock_arun(start_url, crawler, config):
        return _emit_all()

    streaming = MagicMock()
    streaming.arun = mock_arun
    return streaming
|
|
|
|
|
|
def create_batch_strategy(results):
    """Create a mock deep crawl strategy that returns results as a list.

    The returned MagicMock has an async ``arun(start_url, crawler, config)``
    which ignores its arguments and hands back the very same ``results``
    object it was created with.
    """

    async def mock_arun(start_url, crawler, config):
        return results

    batch = MagicMock()
    batch.arun = mock_arun
    return batch
|
|
|
|
|
|
def create_config(stream=False, strategy=None):
    """Build a MagicMock run config for the decorator under test.

    Only two attributes are pinned: ``stream`` and ``deep_crawl_strategy``
    (taken from the ``strategy`` argument). Everything else is whatever
    MagicMock auto-creates on access.
    """
    return MagicMock(stream=stream, deep_crawl_strategy=strategy)
|
|
|
|
|
|
# ============================================================================
# Tests: ContextVar cross-context safety (the core #1917 bug)
# ============================================================================
|
|
|
|
class TestContextVarCrossContext:
    """Tests that deep_crawl_active ContextVar works across task boundaries."""

    @pytest.mark.asyncio
    async def test_streaming_generator_consumed_in_different_task(self):
        """
        Core reproduction of issue #1917:
        Create the generator in one task, consume it in another.
        Before the fix, this raised ValueError.
        """
        mock_results = [create_mock_result(f"https://example.com/{i}") for i in range(3)]
        strategy = create_streaming_strategy(mock_results)
        config = create_config(stream=True, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)

        # Call wrapped_arun in the current task — the flag is set here
        gen = await wrapped("https://example.com", config=config)

        # Consume in a DIFFERENT task (simulates Starlette's StreamingResponse)
        collected = []

        async def consume_in_new_task():
            async for result in gen:
                collected.append(result)

        task = asyncio.create_task(consume_in_new_task())
        await task

        assert len(collected) == 3

    @pytest.mark.asyncio
    async def test_batch_mode_in_different_task(self):
        """Non-streaming mode should also work across task boundaries."""
        mock_results = [create_mock_result("https://example.com")]
        strategy = create_batch_strategy(mock_results)
        config = create_config(stream=False, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)

        # Run the wrapped call inside its own task so the test really crosses
        # a task boundary, as its name promises. The thin coroutine wrapper
        # guarantees create_task() receives a coroutine object.
        async def run_in_new_task():
            return await wrapped("https://example.com", config=config)

        result = await asyncio.create_task(run_in_new_task())

        assert result == mock_results
|
|
|
|
|
|
# ============================================================================
# Tests: ContextVar state management
# ============================================================================
|
|
|
|
class TestContextVarState:
    """Tests that deep_crawl_active is properly managed."""

    @pytest.mark.asyncio
    async def test_flag_is_false_after_streaming_completes(self):
        """deep_crawl_active should be False after the generator is exhausted."""
        strategy = create_streaming_strategy([create_mock_result()])
        config = create_config(stream=True, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)
        gen = await wrapped("https://example.com", config=config)

        async for _ in gen:
            pass

        assert decorator.deep_crawl_active.get() is False

    @pytest.mark.asyncio
    async def test_flag_is_false_after_batch_completes(self):
        """deep_crawl_active should be False after batch mode completes."""
        strategy = create_batch_strategy([create_mock_result()])
        config = create_config(stream=False, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)
        await wrapped("https://example.com", config=config)

        assert decorator.deep_crawl_active.get() is False

    @pytest.mark.asyncio
    async def test_flag_is_true_during_deep_crawl(self):
        """deep_crawl_active should be True while the generator is being consumed."""
        flag_during_yield = None

        async def capturing_arun(start_url, crawler, config):
            async def gen():
                nonlocal flag_during_yield
                # Sample the flag from inside the strategy's generator, i.e.
                # while the deep crawl is in progress.
                flag_during_yield = DeepCrawlDecorator.deep_crawl_active.get()
                yield create_mock_result(start_url)
            return gen()

        strategy = MagicMock()
        strategy.arun = capturing_arun
        config = create_config(stream=True, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)
        gen = await wrapped("https://example.com", config=config)

        async for _ in gen:
            pass

        assert flag_during_yield is True

    @pytest.mark.asyncio
    async def test_flag_prevents_recursive_deep_crawl(self):
        """When deep_crawl_active is True, nested calls should skip deep crawl."""
        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        inner_call_hit = False

        async def original_arun(url, config=None, **kwargs):
            nonlocal inner_call_hit
            inner_call_hit = True
            return create_mock_result(url)

        wrapped = decorator(original_arun)

        # Manually set the flag to simulate being inside a deep crawl
        decorator.deep_crawl_active.set(True)
        try:
            strategy = create_batch_strategy([create_mock_result()])
            config = create_config(stream=False, strategy=strategy)
            # Should call original_arun directly, NOT go through strategy
            await wrapped("https://example.com", config=config)
            assert inner_call_hit is True
        finally:
            decorator.deep_crawl_active.set(False)

    @pytest.mark.asyncio
    async def test_flag_reset_after_streaming_error(self):
        """deep_crawl_active should be reset even if the generator raises."""
        strategy = MagicMock()

        async def failing_arun(start_url, crawler, config):
            async def gen():
                yield create_mock_result("https://example.com")
                raise RuntimeError("simulated error")
            return gen()

        strategy.arun = failing_arun
        config = create_config(stream=True, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)
        gen = await wrapped("https://example.com", config=config)

        with pytest.raises(RuntimeError, match="simulated error"):
            async for _ in gen:
                pass

        assert decorator.deep_crawl_active.get() is False

    @pytest.mark.asyncio
    async def test_flag_reset_after_streaming_error_in_different_task(self):
        """
        Combines #1917 fix with error handling: generator raises in a different task.
        Both the cross-context issue and error cleanup must work together.
        """
        strategy = MagicMock()

        async def failing_arun(start_url, crawler, config):
            async def gen():
                yield create_mock_result("https://example.com")
                raise RuntimeError("simulated error")
            return gen()

        strategy.arun = failing_arun
        config = create_config(stream=True, strategy=strategy)

        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)
        gen = await wrapped("https://example.com", config=config)

        error_caught = False
        flag_after_error = None

        async def consume_in_new_task():
            nonlocal error_caught, flag_after_error
            try:
                async for _ in gen:
                    pass
            except RuntimeError:
                error_caught = True
                # Read the flag from inside the consuming task: a task runs in
                # its own copy of the context, and the generator's cleanup
                # executes here, so this is where the reset is observable.
                flag_after_error = decorator.deep_crawl_active.get()

        task = asyncio.create_task(consume_in_new_task())
        await task

        assert error_caught is True
        assert flag_after_error is False
|
|
|
|
|
|
# ============================================================================
# Tests: Concurrent requests
# ============================================================================
|
|
|
|
class TestConcurrentRequests:
    """Tests that multiple concurrent streaming deep crawls don't interfere."""

    @pytest.mark.asyncio
    async def test_concurrent_streaming_in_separate_tasks(self):
        """
        Multiple concurrent streaming requests consumed in separate tasks.
        This simulates multiple clients hitting /crawl/stream simultaneously.
        """
        crawler = MagicMock()
        decorator = DeepCrawlDecorator(crawler)

        async def original_arun(url, config=None, **kwargs):
            return create_mock_result(url)

        wrapped = decorator(original_arun)
        results_per_request = {}

        async def simulate_request(request_id):
            mock_results = [create_mock_result(f"https://example.com/{request_id}/{i}") for i in range(2)]
            strategy = create_streaming_strategy(mock_results)
            config = create_config(stream=True, strategy=strategy)

            gen = await wrapped(f"https://example.com/{request_id}", config=config)
            results = []
            async for result in gen:
                results.append(result)
                await asyncio.sleep(0.01)  # Interleave with other requests
            results_per_request[request_id] = results

        tasks = [asyncio.create_task(simulate_request(i)) for i in range(3)]
        await asyncio.gather(*tasks)

        assert len(results_per_request) == 3
        for request_id, results in results_per_request.items():
            assert len(results) == 2
            # Each request must have received exactly its own results, in
            # order — a count check alone would not catch cross-talk.
            assert [r.url for r in results] == [
                f"https://example.com/{request_id}/{i}" for i in range(2)
            ]

        # The flag was only ever set inside the per-request tasks, so the
        # test's own context must still see it as False.
        assert decorator.deep_crawl_active.get() is False
|