mirror of
https://github.com/unclecode/crawl4ai.git
synced 2026-06-10 15:58:15 +00:00
fix: add TTL expiry for Redis task data to prevent memory growth (#1730)
From PR #1730 by @hoi
This commit is contained in:
@@ -44,7 +44,8 @@ from utils import (
|
||||
get_llm_api_key,
|
||||
validate_llm_provider,
|
||||
get_llm_temperature,
|
||||
get_llm_base_url
|
||||
get_llm_base_url,
|
||||
get_redis_task_ttl
|
||||
)
|
||||
from webhook import WebhookDeliveryService
|
||||
|
||||
@@ -61,6 +62,21 @@ def _get_memory_mb():
|
||||
return None
|
||||
|
||||
|
||||
async def hset_with_ttl(redis, key: str, mapping: dict, config: dict):
    """Set a Redis hash and apply the configured task TTL to its key.

    When the TTL is positive, the HSET and the EXPIRE are queued on a single
    transactional pipeline so the key is never left written without an
    expiry. When the TTL is 0 or negative, only the hash is written and no
    expiry is set.

    Args:
        redis: Redis client instance. NOTE(review): assumed to be a redis-py
            asyncio client exposing ``pipeline(transaction=True)`` -- confirm
            against the place where the client is created.
        key: Redis key (e.g., "task:abc123")
        mapping: Hash field-value mapping
        config: Application config containing redis.task_ttl_seconds
    """
    # Resolve the TTL before touching Redis: if the configured value is
    # unusable, the failure happens before any key has been written.
    ttl = get_redis_task_ttl(config)

    if ttl > 0:
        # MULTI/EXEC: both commands are applied together, so a dropped
        # connection between them cannot leave a key that never expires.
        async with redis.pipeline(transaction=True) as pipe:
            pipe.hset(key, mapping=mapping)
            pipe.expire(key, ttl)
            await pipe.execute()
    else:
        # TTL disabled: plain write, the key persists until deleted.
        await redis.hset(key, mapping=mapping)
|
||||
|
||||
|
||||
async def handle_llm_qa(
|
||||
url: str,
|
||||
query: str,
|
||||
@@ -147,10 +163,10 @@ async def process_llm_extraction(
|
||||
# Validate provider
|
||||
is_valid, error_msg = validate_llm_provider(config, provider)
|
||||
if not is_valid:
|
||||
await redis.hset(f"task:{task_id}", mapping={
|
||||
await hset_with_ttl(redis, f"task:{task_id}", {
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": error_msg
|
||||
})
|
||||
}, config)
|
||||
|
||||
# Send webhook notification on failure
|
||||
await webhook_service.notify_job_completion(
|
||||
@@ -187,10 +203,10 @@ async def process_llm_extraction(
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
await redis.hset(f"task:{task_id}", mapping={
|
||||
await hset_with_ttl(redis, f"task:{task_id}", {
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": result.error_message
|
||||
})
|
||||
}, config)
|
||||
|
||||
# Send webhook notification on failure
|
||||
await webhook_service.notify_job_completion(
|
||||
@@ -210,10 +226,10 @@ async def process_llm_extraction(
|
||||
|
||||
result_data = {"extracted_content": content}
|
||||
|
||||
await redis.hset(f"task:{task_id}", mapping={
|
||||
await hset_with_ttl(redis, f"task:{task_id}", {
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"result": json.dumps(content)
|
||||
})
|
||||
}, config)
|
||||
|
||||
# Send webhook notification on successful completion
|
||||
await webhook_service.notify_job_completion(
|
||||
@@ -227,10 +243,10 @@ async def process_llm_extraction(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
|
||||
await redis.hset(f"task:{task_id}", mapping={
|
||||
await hset_with_ttl(redis, f"task:{task_id}", {
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
})
|
||||
}, config)
|
||||
|
||||
# Send webhook notification on failure
|
||||
await webhook_service.notify_job_completion(
|
||||
@@ -438,7 +454,7 @@ async def create_new_task(
|
||||
if webhook_config:
|
||||
task_data["webhook_config"] = json.dumps(webhook_config)
|
||||
|
||||
await redis.hset(f"task:{task_id}", mapping=task_data)
|
||||
await hset_with_ttl(redis, f"task:{task_id}", task_data, config)
|
||||
|
||||
background_tasks.add_task(
|
||||
process_llm_extraction,
|
||||
@@ -799,7 +815,7 @@ async def handle_crawl_job(
|
||||
if webhook_config:
|
||||
task_data["webhook_config"] = json.dumps(webhook_config)
|
||||
|
||||
await redis.hset(f"task:{task_id}", mapping=task_data)
|
||||
await hset_with_ttl(redis, f"task:{task_id}", task_data, config)
|
||||
|
||||
# Initialize webhook service
|
||||
webhook_service = WebhookDeliveryService(config)
|
||||
@@ -812,10 +828,10 @@ async def handle_crawl_job(
|
||||
crawler_config=crawler_config,
|
||||
config=config,
|
||||
)
|
||||
await redis.hset(f"task:{task_id}", mapping={
|
||||
await hset_with_ttl(redis, f"task:{task_id}", {
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"result": json.dumps(result),
|
||||
})
|
||||
}, config)
|
||||
|
||||
# Send webhook notification on successful completion
|
||||
await webhook_service.notify_job_completion(
|
||||
@@ -829,10 +845,10 @@ async def handle_crawl_job(
|
||||
|
||||
await asyncio.sleep(5) # Give Redis time to process the update
|
||||
except Exception as exc:
|
||||
await redis.hset(f"task:{task_id}", mapping={
|
||||
await hset_with_ttl(redis, f"task:{task_id}", {
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(exc),
|
||||
})
|
||||
}, config)
|
||||
|
||||
# Send webhook notification on failure
|
||||
await webhook_service.notify_job_completion(
|
||||
|
||||
@@ -14,20 +14,21 @@ llm:
|
||||
# api_key: sk-... # If you pass the API key directly (not recommended)
|
||||
|
||||
# Redis Configuration
|
||||
# Set task_ttl_seconds to automatically expire task data in Redis.
|
||||
# This prevents unbounded memory growth from accumulated task results.
|
||||
# Default: 3600 (1 hour). Set to 0 to disable TTL (not recommended).
|
||||
# Can be overridden with the REDIS_TASK_TTL environment variable.
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
db: 0
|
||||
password: ""
|
||||
task_ttl_seconds: 3600 # TTL for task data (1 hour default)
|
||||
ssl: False
|
||||
ssl_cert_reqs: None
|
||||
ssl_ca_certs: None
|
||||
ssl_certfile: None
|
||||
ssl_keyfile: None
|
||||
ssl_cert_reqs: None
|
||||
ssl_ca_certs: None
|
||||
ssl_certfile: None
|
||||
ssl_keyfile: None
|
||||
|
||||
# Rate Limiting Configuration
|
||||
rate_limiting:
|
||||
|
||||
@@ -36,7 +36,16 @@ def load_config() -> Dict:
|
||||
if llm_api_key and "api_key" not in config["llm"]:
|
||||
config["llm"]["api_key"] = llm_api_key
|
||||
logging.info("LLM API key loaded from LLM_API_KEY environment variable")
|
||||
|
||||
|
||||
# Override Redis task TTL from environment if set
|
||||
redis_task_ttl = os.environ.get("REDIS_TASK_TTL")
|
||||
if redis_task_ttl:
|
||||
try:
|
||||
config["redis"]["task_ttl_seconds"] = int(redis_task_ttl)
|
||||
logging.info(f"Redis task TTL overridden from REDIS_TASK_TTL: {redis_task_ttl}s")
|
||||
except ValueError:
|
||||
logging.warning(f"Invalid REDIS_TASK_TTL value: {redis_task_ttl}, using default")
|
||||
|
||||
return config
|
||||
|
||||
def setup_logging(config: Dict) -> None:
|
||||
@@ -70,6 +79,17 @@ def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
|
||||
return {k.decode('utf-8'): v.decode('utf-8') for k, v in hash_data.items()}
|
||||
|
||||
|
||||
def get_redis_task_ttl(config: Dict) -> int:
    """Get Redis task TTL in seconds from config.

    The value is read from ``config["redis"]["task_ttl_seconds"]`` and always
    returned as a non-negative ``int`` so callers can compare it with ``0``
    and pass it to Redis without further checks.

    Args:
        config: The application configuration dictionary

    Returns:
        TTL in seconds (default 3600). Returns 0 if TTL is disabled.
        A missing, null, or non-numeric setting falls back to the default;
        a negative setting is treated as disabled (0).
    """
    default_ttl = 3600

    # A YAML "redis:" key with no body loads as None rather than a dict.
    redis_section = config.get("redis")
    if not isinstance(redis_section, dict):
        return default_ttl

    raw_ttl = redis_section.get("task_ttl_seconds", default_ttl)
    try:
        # Accept numeric strings (e.g. a quoted YAML value) as well as ints.
        ttl = int(raw_ttl)
    except (TypeError, ValueError, OverflowError):
        return default_ttl

    # Negative values carry no meaning as an expiry; report them as disabled.
    return max(ttl, 0)
|
||||
|
||||
|
||||
def get_llm_api_key(config: Dict, provider: Optional[str] = None) -> Optional[str]:
|
||||
"""Get the appropriate API key based on the LLM provider.
|
||||
|
||||
Reference in New Issue
Block a user