Files
tabbyAPI/tests/test_context_length_errors.py
turboderp 538654bfcb Add tests
2026-06-27 02:37:30 +02:00

219 lines
8.0 KiB
Python

import unittest
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
import httpx
from fastapi import FastAPI, HTTPException
from common import model
from common.errors import (
ContextLengthExceededError,
ContextLengthHTTPException,
context_length_exception_handler,
validate_context_requirements,
)
from common.networking import get_context_length_generator_error
from endpoints.Kobold.utils import generation as kobold_generation
from endpoints.OAI.utils import chat_completion, completion
class DummyDisconnectHandler:
async def cleanup(self):
pass
class DummyRequestData:
n = 1
stream_options = None
def model_copy(self, deep=False):
return self
def model_dump(self, mode=None):
return {}
def request_with_id(request_id="request-id"):
return SimpleNamespace(state=SimpleNamespace(id=request_id))
class ContextLengthErrorTests(unittest.IsolatedAsyncioTestCase):
async def test_context_length_http_error_uses_openai_error_shape(self):
app = FastAPI()
app.add_exception_handler(ContextLengthHTTPException, context_length_exception_handler)
@app.get("/")
async def raise_context_error():
raise ContextLengthHTTPException("Prompt exceeds the available context size")
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/")
self.assertEqual(response.status_code, 400)
self.assertEqual(
response.json(),
{
"error": {
"message": "Prompt exceeds the available context size",
"type": "invalid_request_error",
"param": None,
"code": "context_length_exceeded",
}
},
)
def test_context_length_stream_error_uses_openai_error_shape(self):
message = "Prompt exceeds the available context size"
self.assertEqual(
httpx.Response(200, text=get_context_length_generator_error(message)).json(),
{
"error": {
"message": message,
"type": "invalid_request_error",
"param": None,
"code": "context_length_exceeded",
}
},
)
def test_prompt_limit_message_is_litellm_compatible(self):
with self.assertRaises(ContextLengthExceededError) as raised:
validate_context_requirements(4097, 4096, 1, 4096)
self.assertIn("exceeds the available context size", str(raised.exception))
def test_exllamav3_preflight_accounts_for_requested_completion(self):
with self.assertRaises(ContextLengthExceededError) as raised:
validate_context_requirements(3500, 8192, 1000, 4096)
self.assertIn("requires 4500 cache tokens", str(raised.exception))
self.assertIn("exceeds the available context size of 4096 tokens", str(raised.exception))
def test_exllamav3_preflight_allows_requeueable_completion(self):
validate_context_requirements(1000, 8192, 1000, 4096, max_rq_tokens=2048)
def test_exllamav3_preflight_uses_automatic_completion_limit(self):
validate_context_requirements(3000, 4096, 0, 4096)
def test_exllamav3_preflight_accounts_for_requeue_allocation_window(self):
with self.assertRaises(ContextLengthExceededError) as raised:
validate_context_requirements(4000, 8192, 10, 4096, max_rq_tokens=2048)
self.assertIn("requires 6144 cache tokens", str(raised.exception))
def test_exllamav3_preflight_ignores_later_requeue_windows(self):
validate_context_requirements(1000, 8192, 10_000_000, 4096, max_rq_tokens=2048)
def test_streaming_preflight_returns_400_for_context_length_error(self):
error = ContextLengthExceededError("Prompt length 9 is greater than max_seq_len 8")
container = SimpleNamespace(
validate_context_length=lambda *args: (_ for _ in ()).throw(error)
)
original_container = model.container
model.container = container
try:
with self.assertRaises(HTTPException) as raised:
model.check_context_length("prompt", DummyRequestData())
finally:
model.container = original_container
self.assertEqual(raised.exception.status_code, 400)
self.assertEqual(raised.exception.detail, str(error))
def test_streaming_preflight_checks_each_batched_prompt(self):
checked_prompts = []
container = SimpleNamespace(
validate_context_length=lambda prompt, *args: checked_prompts.append(prompt)
)
original_container = model.container
model.container = container
try:
model.check_context_length(["first", "second"], DummyRequestData())
finally:
model.container = original_container
self.assertEqual(checked_prompts, ["first", "second"])
async def test_completion_returns_400_for_context_length_error(self):
error = ContextLengthExceededError("Prompt length 9 is greater than max_seq_len 8")
async def collector(*args, **kwargs):
return error
with patch.object(completion, "_stream_collector", collector):
with self.assertRaises(HTTPException) as raised:
await completion.generate_completion(
"prompt",
DummyRequestData(),
request_with_id(),
Path("model"),
DummyDisconnectHandler(),
)
self.assertEqual(raised.exception.status_code, 400)
self.assertEqual(raised.exception.detail, str(error))
async def test_chat_completion_returns_400_for_context_length_error(self):
error = ContextLengthExceededError("Prompt length 9 is greater than max_seq_len 8")
async def collector(*args, **kwargs):
return error
original_container = model.container
model.container = SimpleNamespace(reasoning=False)
try:
with patch.object(chat_completion, "_chat_stream_collector", collector):
with self.assertRaises(HTTPException) as raised:
await chat_completion.generate_chat_completion(
"prompt",
None,
DummyRequestData(),
request_with_id(),
Path("model"),
DummyDisconnectHandler(),
)
finally:
model.container = original_container
self.assertEqual(raised.exception.status_code, 400)
self.assertEqual(raised.exception.detail, str(error))
async def test_kobold_generation_returns_400_for_context_length_error(self):
error = ContextLengthExceededError("Prompt length 9 is greater than max_seq_len 8")
data = SimpleNamespace(genkey=None)
async def collector(*args, **kwargs):
raise error
yield
with patch.object(kobold_generation, "_stream_collector", collector):
with self.assertRaises(HTTPException) as raised:
await kobold_generation.get_generation(data, request_with_id())
self.assertEqual(raised.exception.status_code, 400)
self.assertEqual(raised.exception.detail, str(error))
async def test_other_completion_errors_remain_503(self):
async def collector(*args, **kwargs):
return ValueError("backend failure")
with patch.object(completion, "_stream_collector", collector):
with self.assertRaises(HTTPException) as raised:
await completion.generate_completion(
"prompt",
DummyRequestData(),
request_with_id(),
Path("model"),
DummyDisconnectHandler(),
)
self.assertEqual(raised.exception.status_code, 503)
if __name__ == "__main__":
unittest.main()