Merge branch 'main' into main

This commit is contained in:
Bartowski
2024-08-14 16:16:15 -04:00
committed by GitHub
56 changed files with 2609 additions and 888 deletions


@@ -1,35 +0,0 @@
---
name: Bug report
about: Report code-related issues
title: "[BUG]"
labels: bug
assignees: ''
---
**Disclaimer:** GitHub Issues are **only** for code-related bugs. If you do not understand how to start up or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj).
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Logs**
If applicable, add logs and tracebacks to help explain your problem.
**System info** (Bugs without this information will go lower on our priority list!)
- OS: [ex. Windows]
- Python version: [ex. 3.11]
- CUDA/ROCm version: [ex. 12.x]
**Additional context**
Add any other context about the problem here.

.github/ISSUE_TEMPLATE/bug_report.yaml

@@ -0,0 +1,97 @@
name: Bug report
description: Report code-related issues
title: "[BUG]"
labels: bug
body:
- type: markdown
attributes:
value: |
### Disclaimer:
GitHub Issues are **only** for code-related bugs.
If you do not understand how to start up or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj).
- type: dropdown
attributes:
label: OS
options:
- Windows
- Linux
validations:
required: true
- type: dropdown
attributes:
label: GPU Library
description: Ex. CUDA, ROCm
options:
- CUDA 12.x
- CUDA 11.8
- AMD ROCm
validations:
required: true
- type: dropdown
attributes:
label: Python version
options:
- '3.12'
- '3.11'
- '3.10'
validations:
required: true
- type: textarea
attributes:
label: Describe the bug
description: A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
attributes:
label: Reproduction steps
description: Walk us through how the bug occurred and how to make it happen.
validations:
required: true
- type: textarea
attributes:
label: Expected behavior
description: What was expected to happen?
validations:
required: true
- type: textarea
attributes:
label: Logs
description: If applicable, add logs and tracebacks to help explain your problem.
validations:
required: false
- type: textarea
attributes:
label: Additional context
description: Add any other context about the problem here.
validations:
required: false
- type: checkboxes
attributes:
label: Acknowledgements
description: Before submitting this issue, please make sure you have completed the following checklist.
options:
- label: I have looked for similar issues before submitting this one.
required: true
- label: I have read the disclaimer, and this issue is related to a code bug. If I have a question, I will use the Discord server.
required: true
- label: I understand that the developers have lives and my issue will be answered when possible.
required: true
- label: I understand the developers of this program are human, and I will ask my questions politely.
required: true
- type: markdown
attributes:
value: |
## Thanks!
Well-formatted issues improve TabbyAPI and make the development process smoother.


@@ -1,26 +0,0 @@
---
name: Feature request
about: Suggest a new idea
title: "[REQUEST]"
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Why should this feature be added?**
An explanation of why the feature should be added. Please be as specific as possible to help us understand the reasoning.
**Examples**
Examples of the feature in action and its significance compared to not having the feature.
**Additional context**
Add any other context or screenshots about the feature request here.


@@ -0,0 +1,69 @@
name: Feature request
description: Suggest a new idea
title: "[REQUEST]"
body:
- type: textarea
attributes:
label: Problem
description: Is the feature request related to a problem? If so, please describe.
placeholder: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
validations:
required: false
- type: textarea
attributes:
label: Solution
description: Describe the solution you'd like.
placeholder: A clear and concise description of what you want to happen.
validations:
required: true
- type: textarea
attributes:
label: Alternatives
description: What alternative options did you consider?
validations:
required: false
- type: textarea
attributes:
label: Explanation
description: Why should this feature be added?
validations:
required: true
- type: textarea
attributes:
label: Examples
description: |
Examples of the feature in action and its significance.
Not required, but will make your request easier to understand.
validations:
required: false
- type: textarea
attributes:
label: Additional context
description: Anything else to add?
validations:
required: false
- type: checkboxes
attributes:
label: Acknowledgements
description: Before submitting this issue, please make sure you have completed the following checklist.
options:
- label: I have looked for similar requests before submitting this one.
required: true
- label: I understand that the developers have lives and my issue will be answered when possible.
required: true
- label: I understand the developers of this program are human, and I will make my requests politely.
required: true
- type: markdown
attributes:
value: |
## Thanks!
Well-formatted issues improve TabbyAPI and make the development process smoother.


@@ -47,12 +47,17 @@ jobs:
run: |
npm install @redocly/cli -g
- name: Export OpenAPI docs
run: EXPORT_OPENAPI=1 python main.py
run: |
EXPORT_OPENAPI=1 python main.py
mv openapi.json openapi-oai.json
EXPORT_OPENAPI=1 python main.py --api-servers kobold
mv openapi.json openapi-kobold.json
- name: Build and store Redocly site
run: |
redocly build-docs openapi.json
mkdir static
mv redoc-static.html static/index.html
mkdir static/kobold
redocly build-docs openapi-oai.json -o static/index.html
redocly build-docs openapi-kobold.json -o static/kobold/index.html
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Upload artifact

.gitignore

@@ -200,5 +200,11 @@ sampler_overrides/*
# Gpu lib preferences file
gpu_lib.txt
# Start options file
start_options.json
# OpenAPI JSON
openapi.json
# Infinity-emb cache
.infinity_cache/


@@ -32,20 +32,41 @@
A FastAPI based application that allows for generating text using an LLM (large language model) using the [Exllamav2 backend](https://github.com/turboderp/exllamav2)
TabbyAPI is also the official API backend server for ExllamaV2.
## Disclaimer
This project is marked rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies if needed.
This project is marked as a rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies.
TabbyAPI is a hobby project solely for a small amount of users. It is not meant to run on production servers. For that, please look at other backends that support those workloads.
TabbyAPI is a hobby project made for a small number of users. It is not meant to run on production servers. For that, please look at other solutions that support those workloads.
## Getting Started
> [!IMPORTANT]
>
> This README is not for getting started. Please read the Wiki.
> This README does not contain setup instructions. Please read the Wiki.
Read the [Wiki](https://github.com/theroyallab/tabbyAPI/wiki/1.-Getting-Started) for more information. It contains user-facing documentation for installation, configuration, sampling, API usage, and so much more.
## Features
- OpenAI compatible API
- Loading/unloading models
- HuggingFace model downloading
- Embedding model support
- JSON schema + Regex + EBNF support
- AI Horde support
- Speculative decoding via draft models
- Multi-lora with independent scaling (ex. a weight of 0.9)
- Inbuilt proxy to override client request parameters/samplers
- Flexible Jinja2 template engine for chat completions that conforms to HuggingFace
- Concurrent inference with asyncio
- Utilizes modern Python paradigms
- Continuous batching engine using paged attention
- Fast classifier-free guidance
And much more. If something is missing here, PR it in!
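As a quick illustration of the JSON schema feature listed above, a completion request body might look like the following. This is a sketch only: the field names follow the OAI-style API, `json_schema` is a TabbyAPI-specific extension, and the exact endpoint and parameter names should be checked against the Wiki.

```python
import json

# Hypothetical request body for an OAI-style completion call with a
# JSON schema constraint ("json_schema" is the TabbyAPI-specific field)
payload = {
    "prompt": "List one fruit as a JSON object.",
    "max_tokens": 64,
    "json_schema": {
        "type": "object",
        "properties": {"fruit": {"type": "string"}},
        "required": ["fruit"],
    },
}
body = json.dumps(payload)
```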
## Supported Model Types
TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, loading, etc. Therefore, the following types of models are supported:
@@ -58,18 +79,6 @@ TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, load
In addition, TabbyAPI supports parallel batching using paged attention for Nvidia Ampere GPUs and higher.
#### Alternative Loaders/Backends
If you want to use a different model type or quantization method than the ones listed above, here are some alternative backends with their own APIs:
- GGUF + GGML - [KoboldCPP](https://github.com/lostruins/KoboldCPP)
- Production ready + Many other quants + batching - [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine)
- Production ready + batching - [VLLM](https://github.com/vllm-project/vllm)
- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui)
## Contributing
Use the template when creating issues or pull requests, otherwise the developers may not look at your post.
@@ -84,6 +93,17 @@ If you have a Pull Request
- Describe the pull request in detail: what you are changing and why
## Acknowledgements
TabbyAPI would not exist without the work of other contributors and FOSS projects:
- [ExllamaV2](https://github.com/turboderp/exllamav2)
- [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine)
- [infinity-emb](https://github.com/michaelfeil/infinity)
- [FastAPI](https://github.com/fastapi/fastapi)
- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui)
- [SillyTavern](https://github.com/SillyTavern/SillyTavern)
## Developers and Permissions
Creators/Developers:


@@ -47,7 +47,7 @@ from common.templating import (
TemplateLoadError,
find_template_from_model,
)
from common.transformers_utils import GenerationConfig
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
from common.utils import coalesce, unwrap
@@ -70,8 +70,9 @@ class ExllamaV2Container:
cache_size: int = None
cache_mode: str = "FP16"
draft_cache_mode: str = "FP16"
max_batch_size: int = 20
max_batch_size: Optional[int] = None
generation_config: Optional[GenerationConfig] = None
hf_config: Optional[HuggingFaceConfig] = None
# GPU split vars
gpu_split: Optional[list] = None
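`unwrap` and `coalesce` from `common.utils` appear throughout these hunks; their implementations are not shown in this diff, but a plausible sketch of their semantics (an assumption, not the project's actual code) is:

```python
def unwrap(value, default=None):
    # Return value unless it is None; unlike `or`, this keeps falsy
    # values such as 0 or "" intact
    return value if value is not None else default

def coalesce(*args):
    # Return the first non-None argument, or None if all are None
    return next((a for a in args if a is not None), None)
```

This distinction matters for settings like `max_batch_size`, where `0` and `None` must be treated differently.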
@@ -186,6 +187,9 @@ class ExllamaV2Container:
except AttributeError:
pass
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
@@ -213,6 +217,9 @@ class ExllamaV2Container:
# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
# Set max batch size to the config override
self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
# Check whether the user's configuration supports flash/paged attention
# Also check if exl2 has disabled flash attention
if (
@@ -268,15 +275,8 @@ class ExllamaV2Container:
else:
self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Load generation config overrides
generation_config_path = (
pathlib.Path(self.config.model_dir) / "generation_config.json"
)
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
@@ -288,6 +288,11 @@ class ExllamaV2Container:
"Skipping generation config load because of an unexpected error."
)
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Catch all for template lookup errors
if self.prompt_template:
logger.info(
@@ -405,6 +410,9 @@ class ExllamaV2Container:
def get_model_path(self, is_draft: bool = False):
"""Get the path for this model."""
if is_draft and not self.draft_config:
return None
model_path = pathlib.Path(
self.draft_config.model_dir if is_draft else self.config.model_dir
)
@@ -446,6 +454,11 @@ class ExllamaV2Container:
# Immediately abort all jobs if asked
if skip_wait:
logger.warning(
"Immediately terminating all jobs. "
"Clients will have their requests cancelled.\n"
)
# Requires a copy to avoid errors during iteration
jobs_copy = self.generator.jobs.copy()
for job in jobs_copy.values():
@@ -485,15 +498,7 @@ class ExllamaV2Container:
yield value
# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
paged=self.paged,
)
await self.create_generator()
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
@@ -642,6 +647,34 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
async def create_generator(self):
try:
# Don't acquire locks unless a model is loaded
if self.model_loaded:
await self.load_lock.acquire()
# Immediately cancel all jobs
await self.wait_for_jobs(skip_wait=True)
# Create new generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
paged=self.paged,
)
finally:
# This means the generator is being recreated
# The load lock is already released in the load function
if self.model_loaded:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
def get_loras(self):
"""Convenience function to get all loras."""
@@ -701,11 +734,15 @@ class ExllamaV2Container:
Free all VRAM resources used by this model
"""
try:
await self.load_lock.acquire()
# Shutdown immediately unloads and bypasses all locks
do_shutdown = kwargs.get("shutdown")
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
try:
if not do_shutdown:
await self.load_lock.acquire()
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
# Delete references held in the grammar module
clear_grammar_func_cache()
@@ -745,10 +782,11 @@ class ExllamaV2Container:
logger.info("Loras unloaded." if loras_only else "Model unloaded.")
finally:
self.load_lock.release()
if not do_shutdown:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
async with self.load_condition:
self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
@@ -800,10 +838,14 @@ class ExllamaV2Container:
return dict(zip_longest(top_tokens, cleaned_values))
async def generate(self, prompt: str, **kwargs):
async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
):
"""Generate a response to a prompt"""
generations = []
async for generation in self.generate_gen(prompt, **kwargs):
async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
):
generations.append(generation)
joined_generation = {
@@ -853,7 +895,11 @@ class ExllamaV2Container:
return kwargs
async def generate_gen(
self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
self,
prompt: str,
request_id: str,
abort_event: Optional[asyncio.Event] = None,
**kwargs,
):
"""
Create generator function for prompt completion.
@@ -972,8 +1018,37 @@ class ExllamaV2Container:
kwargs.get("repetition_decay"), fallback_decay, 0
)
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()
# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
if json_schema:
grammar_handler.add_json_schema_filter(
json_schema, self.model, self.tokenizer
)
# Add regex filter if it exists
regex_pattern = unwrap(kwargs.get("regex_pattern"))
if regex_pattern:
grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
if grammar_string:
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
# Set banned strings
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
if banned_strings and len(grammar_handler.filters) > 0:
logger.warning(
"Disabling banned_strings because "
"they cannot be used with grammar filters."
)
banned_strings = []
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
logit_bias = kwargs.get("logit_bias")
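The interaction between banned strings and grammar filters above reduces to a simple exclusivity check; isolated, with stand-in values:

```python
# Stand-ins for grammar_handler.filters and the user's banned strings
filters = ["json_schema_filter"]
banned_strings = ["secret"]

# Banned strings cannot be combined with grammar filters, so they are
# dropped (the real code also logs a warning)
if banned_strings and len(filters) > 0:
    banned_strings = []
```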
@@ -1021,26 +1096,6 @@ class ExllamaV2Container:
"in the model's vocab. Skipping."
)
# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()
# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
if json_schema:
grammar_handler.add_json_schema_filter(
json_schema, self.model, self.tokenizer
)
# Add regex filter if it exists
regex_pattern = unwrap(kwargs.get("regex_pattern"))
if regex_pattern:
grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
if grammar_string:
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
# Fetch EOS tokens from generation_config if they exist
eos_tokens = (
self.generation_config.eos_tokens()
@@ -1085,34 +1140,11 @@ class ExllamaV2Container:
# This is an inverse of skip_special_tokens
decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
max_tokens=max_tokens,
min_tokens=min_tokens,
stream=kwargs.get("stream"),
**gen_settings_log_dict,
token_healing=token_healing,
auto_scale_penalty_range=auto_scale_penalty_range,
generate_window=generate_window,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=eos_tokens,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
skip_special_tokens=not decode_special_tokens,
speculative_ngram=self.generator.speculative_ngram,
logprobs=request_logprobs,
stop_conditions=stop_conditions,
banned_tokens=banned_tokens,
banned_strings=banned_strings,
logit_bias=logit_bias,
filters=grammar_handler.filters,
)
# Log prompt to console
log_prompt(prompt, negative_prompt)
log_prompt(prompt, request_id, negative_prompt)
# Create and add a new job
# Don't use the request ID here as there can be multiple jobs per request
job_id = uuid.uuid4().hex
job = ExLlamaV2DynamicJobAsync(
self.generator,
@@ -1138,6 +1170,7 @@ class ExllamaV2Container:
max_seq_len = self.config.max_seq_len
generated_tokens = 0
full_response = ""
metrics_result = {}
# Get the generation status once it's ready
try:
@@ -1191,23 +1224,15 @@ class ExllamaV2Container:
# Second yield if eos is true
if result.get("eos"):
log_response(full_response)
log_response(request_id, full_response)
eos_reason = result.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)
log_metrics(
result.get("time_enqueued"),
result.get("prompt_tokens"),
result.get("cached_tokens"),
result.get("time_prefill"),
result.get("new_tokens"),
result.get("time_generate"),
context_len,
max_seq_len,
)
# Save the final result for metrics logging
metrics_result = result
# Remove the token text
generation = {
@@ -1220,3 +1245,53 @@ class ExllamaV2Container:
break
except asyncio.CancelledError:
await job.cancel()
except Exception as ex:
# Create a new generator since the current state is broken
# No need to wait for this to finish
logger.error(
"FATAL ERROR with generation. "
"Attempting to recreate the generator. "
"If this fails, please restart the server.\n"
)
asyncio.ensure_future(self.create_generator())
raise ex
finally:
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
request_id=request_id,
max_tokens=max_tokens,
min_tokens=min_tokens,
stream=kwargs.get("stream"),
**gen_settings_log_dict,
token_healing=token_healing,
auto_scale_penalty_range=auto_scale_penalty_range,
generate_window=generate_window,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=eos_tokens,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
skip_special_tokens=not decode_special_tokens,
speculative_ngram=self.generator.speculative_ngram,
logprobs=request_logprobs,
stop_conditions=stop_conditions,
banned_tokens=banned_tokens,
banned_strings=banned_strings,
logit_bias=logit_bias,
filters=grammar_handler.filters,
)
# Log the metrics if present
if metrics_result:
log_metrics(
request_id,
metrics_result.get("time_enqueued"),
metrics_result.get("prompt_tokens"),
metrics_result.get("cached_tokens"),
metrics_result.get("time_prefill"),
metrics_result.get("new_tokens"),
metrics_result.get("time_generate"),
context_len,
max_seq_len,
)
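The non-streaming `generate` wrapper in this file simply drains `generate_gen` and joins the chunks. The pattern in isolation, with a stand-in async generator in place of the real streaming code:

```python
import asyncio

async def generate_gen(prompt: str):
    # Stand-in for the real streaming generator: yield chunks as they arrive
    for chunk in (prompt.upper(), "!"):
        yield chunk

async def generate(prompt: str) -> str:
    # Collect every streamed chunk into one joined result
    generations = []
    async for generation in generate_gen(prompt):
        generations.append(generation)
    return "".join(generations)

result = asyncio.run(generate("hi"))
```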


@@ -1,20 +1,22 @@
import platform
import torch
from packaging import version
from importlib.metadata import PackageNotFoundError, version as package_version
from loguru import logger
import torch
def check_exllama_version():
"""Verifies the exllama version"""
required_version = version.parse("0.1.6")
required_version = version.parse("0.1.7")
current_version = version.parse(package_version("exllamav2").split("+")[0])
unsupported_message = (
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
f"or greater. Your current version is {current_version}.\n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Please update your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
@@ -71,8 +73,9 @@ def supports_paged_attn():
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Please upgrade your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
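The version gate above strips any local build suffix (e.g. `+cu121`) before comparing. A stdlib-only sketch of that comparison (the real code uses `packaging.version`, which handles more formats):

```python
def parse_version(v: str) -> tuple:
    # Drop a local suffix like "+cu121", then compare numeric components
    return tuple(int(part) for part in v.split("+")[0].split("."))

# Tuple comparison is element-wise, so this matches semantic ordering
meets_requirement = parse_version("0.1.7+cu121") >= parse_version("0.1.7")
```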


@@ -0,0 +1,66 @@
import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional
from common.utils import unwrap
# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
has_infinity_emb = True
except ImportError:
has_infinity_emb = False
class InfinityContainer:
model_dir: pathlib.Path
model_is_loading: bool = False
model_loaded: bool = False
# Conditionally set the type hint based on importability
# TODO: Clean this up
if has_infinity_emb:
engine: Optional[AsyncEmbeddingEngine] = None
else:
engine = None
def __init__(self, model_directory: pathlib.Path):
self.model_dir = model_directory
async def load(self, **kwargs):
self.model_is_loading = True
# Use cpu by default
device = unwrap(kwargs.get("embeddings_device"), "cpu")
engine_args = EngineArgs(
model_name_or_path=str(self.model_dir),
engine="torch",
device=device,
bettertransformer=False,
model_warmup=False,
)
self.engine = AsyncEmbeddingEngine.from_args(engine_args)
await self.engine.astart()
self.model_loaded = True
logger.info("Embedding model successfully loaded.")
async def unload(self):
await self.engine.astop()
self.engine = None
gc.collect()
torch.cuda.empty_cache()
logger.info("Embedding model unloaded.")
async def generate(self, sentence_input: List[str]):
result_embeddings, usage = await self.engine.embed(sentence_input)
return {"embeddings": result_embeddings, "usage": usage}


@@ -17,13 +17,16 @@ def init_argparser():
"""Creates an argument parser that any function can use"""
parser = argparse.ArgumentParser(
epilog="These args are only for a subset of the config. "
+ "Please edit config.yml for all options!"
epilog="NOTE: These args serve to override parts of the config. "
+ "It's highly recommended to edit config.yml for all options and "
+ "better descriptions!"
)
add_network_args(parser)
add_model_args(parser)
add_embeddings_args(parser)
add_logging_args(parser)
add_developer_args(parser)
add_sampling_args(parser)
add_config_args(parser)
return parser
@@ -64,6 +67,17 @@ def add_network_args(parser: argparse.ArgumentParser):
type=str_to_bool,
help="Disable HTTP token authentication with requests",
)
network_group.add_argument(
"--send-tracebacks",
type=str_to_bool,
help="Decide whether to send error tracebacks over the API",
)
network_group.add_argument(
"--api-servers",
type=str,
nargs="+",
help="API servers to enable. Options: (OAI, Kobold)",
)
def add_model_args(parser: argparse.ArgumentParser):
@@ -74,6 +88,17 @@ def add_model_args(parser: argparse.ArgumentParser):
"--model-dir", type=str, help="Overrides the directory to look for models"
)
model_group.add_argument("--model-name", type=str, help="An initial model to load")
model_group.add_argument(
"--use-dummy-models",
type=str_to_bool,
help="Add dummy OAI model names for API queries",
)
model_group.add_argument(
"--use-as-default",
type=str,
nargs="+",
help="Names of args to use as a default fallback for API load requests",
)
model_group.add_argument(
"--max-seq-len", type=int, help="Override the maximum model sequence length"
)
@@ -82,25 +107,17 @@ def add_model_args(parser: argparse.ArgumentParser):
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--cache-size",
type=int,
help="The size of the prompt cache (in number of tokens) to allocate",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the prompt template for chat completions",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
help="Automatically allocate resources to GPUs",
)
model_group.add_argument(
"--autosplit-reserve",
type=int,
nargs="+",
help="Reserve VRAM used for autosplit loading (in MBs)",
)
model_group.add_argument(
"--gpu-split",
type=float,
@@ -108,15 +125,44 @@ def add_model_args(parser: argparse.ArgumentParser):
help="An integer array of GBs of vram to split between GPUs. "
+ "Ignored if gpu_split_auto is true",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
model_group.add_argument(
"--cache-mode",
type=str,
help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
)
model_group.add_argument(
"--cache-size",
type=int,
help="The size of the prompt cache (in number of tokens) to allocate",
)
model_group.add_argument(
"--chunk-size",
type=int,
help="Chunk size for prompt ingestion",
)
model_group.add_argument(
"--max-batch-size",
type=int,
help="Maximum number of prompts to process at one time",
)
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the jinja2 prompt template for chat completions",
)
model_group.add_argument(
"--num-experts-per-token",
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--use-cfg",
"--fasttensors",
type=str_to_bool,
help="Enables CFG support",
help="Possibly increases model loading speeds",
)
@@ -132,6 +178,11 @@ def add_logging_args(parser: argparse.ArgumentParser):
type=str_to_bool,
help="Enable generation parameter logging",
)
logging_group.add_argument(
"--log-requests",
type=str_to_bool,
help="Enable request logging",
)
def add_developer_args(parser: argparse.ArgumentParser):
@@ -149,5 +200,38 @@ def add_developer_args(parser: argparse.ArgumentParser):
developer_group.add_argument(
"--cuda-malloc-backend",
type=str_to_bool,
help="Disables API request streaming",
help="Runs with the pytorch CUDA malloc backend",
)
developer_group.add_argument(
"--uvloop",
type=str_to_bool,
help="Run asyncio using Uvloop or Winloop",
)
def add_sampling_args(parser: argparse.ArgumentParser):
"""Adds sampling-specific arguments"""
sampling_group = parser.add_argument_group("sampling")
sampling_group.add_argument(
"--override-preset", type=str, help="Select a sampler override preset"
)
def add_embeddings_args(parser: argparse.ArgumentParser):
"""Adds arguments specific to embeddings"""
embeddings_group = parser.add_argument_group("embeddings")
embeddings_group.add_argument(
"--embedding-model-dir",
type=str,
help="Overrides the directory to look for embedding models",
)
embeddings_group.add_argument(
"--embedding-model-name", type=str, help="An initial embedding model to load"
)
embeddings_group.add_argument(
"--embeddings-device",
type=str,
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
)
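`str_to_bool` is used as an argparse `type` throughout these argument groups; its implementation is not shown in this diff, so here is a plausible sketch (an assumption, not the project's actual helper):

```python
import argparse

def str_to_bool(value: str) -> bool:
    # Map common truthy/falsy spellings so "--flag false" parses as expected;
    # a plain bool() type would treat any non-empty string as True
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean value: {value!r}")
```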


@@ -5,11 +5,13 @@ application, it should be fine.
import secrets
import yaml
from fastapi import Header, HTTPException
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from loguru import logger
from typing import Optional
from common.utils import coalesce
class AuthKeys(BaseModel):
"""
@@ -75,7 +77,27 @@ def load_auth_keys(disable_from_config: bool):
)
async def validate_key_permission(test_key: str):
def get_key_permission(request: Request):
"""
Gets the key permission from a request.
Internal only! Use the depends functions for incoming requests.
"""
# Give full admin permissions if auth is disabled
if DISABLE_AUTH:
return "admin"
# Hyphens are okay here
test_key = coalesce(
request.headers.get("x-admin-key"),
request.headers.get("x-api-key"),
request.headers.get("authorization"),
)
if test_key is None:
raise ValueError("The provided authentication key is missing.")
if test_key.lower().startswith("bearer"):
test_key = test_key.split(" ")[1]
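The header precedence and bearer-prefix stripping above, isolated with a plain dict standing in for the FastAPI request headers (the real code uses `coalesce`, i.e. first non-None, rather than `or`):

```python
# A plain dict stands in for request.headers (header names lowercased)
headers = {"authorization": "Bearer abc123"}

# Admin key takes precedence, then the API key, then Authorization
test_key = (
    headers.get("x-admin-key")
    or headers.get("x-api-key")
    or headers.get("authorization")
)
if test_key is None:
    raise ValueError("The provided authentication key is missing.")

if test_key.lower().startswith("bearer"):
    # Strip the "Bearer " scheme prefix to leave the raw key
    test_key = test_key.split(" ")[1]
```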


@@ -46,15 +46,12 @@ def from_args(args: dict):
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
# Generation Logging config
gen_logging_override = args.get("logging")
if gen_logging_override:
cur_gen_logging_config = gen_logging_config()
logging_override = args.get("logging")
if logging_override:
cur_logging_config = logging_config()
GLOBAL_CONFIG["logging"] = {
**cur_gen_logging_config,
**{
k.replace("log_", ""): gen_logging_override[k]
for k in gen_logging_override
},
**cur_logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
}
developer_override = args.get("developer")
@@ -62,6 +59,11 @@ def from_args(args: dict):
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
embeddings_override = args.get("embeddings")
if embeddings_override:
cur_embeddings_config = embeddings_config()
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
def sampling_config():
"""Returns the sampling parameter config from the global config"""
@@ -90,11 +92,16 @@ def network_config():
return unwrap(GLOBAL_CONFIG.get("network"), {})
def gen_logging_config():
"""Returns the generation logging config from the global config"""
def logging_config():
"""Returns the logging config from the global config"""
return unwrap(GLOBAL_CONFIG.get("logging"), {})
def developer_config():
"""Returns the developer specific config from the global config"""
return unwrap(GLOBAL_CONFIG.get("developer"), {})
def embeddings_config():
"""Returns the embeddings config from the global config"""
return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
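The override merge above follows one pattern throughout `from_args`: take the current section from `GLOBAL_CONFIG`, then spread the CLI override on top, renaming `log_`-prefixed argument keys so they line up with the config file. A self-contained sketch of that merge (the sample keys are illustrative):

```python
def unwrap(value, default=None):
    """Return value unless it's None, in which case return the default."""
    return value if value is not None else default

GLOBAL_CONFIG = {"logging": {"prompt": False, "generation_params": False}}

def logging_config():
    """Returns the logging config from the global config."""
    return unwrap(GLOBAL_CONFIG.get("logging"), {})

# CLI args use a "log_" prefix; strip it so keys match the config file
logging_override = {"log_prompt": True}
GLOBAL_CONFIG["logging"] = {
    **logging_config(),
    **{k.replace("log_", ""): logging_override[k] for k in logging_override},
}

print(GLOBAL_CONFIG["logging"])  # {'prompt': True, 'generation_params': False}
```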


@@ -101,6 +101,7 @@ async def hf_repo_download(
chunk_limit: Optional[float],
include: Optional[List[str]],
exclude: Optional[List[str]],
timeout: Optional[int],
repo_type: Optional[str] = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""
@@ -145,7 +146,8 @@ async def hf_repo_download(
logger.info(f"Saving {repo_id} to {str(download_path)}")
try:
async with aiohttp.ClientSession() as session:
client_timeout = aiohttp.ClientTimeout(total=timeout)  # total=None disables the timeout
async with aiohttp.ClientSession(timeout=client_timeout) as session:
tasks = []
logger.info(f"Starting download for {repo_id}")


@@ -51,25 +51,31 @@ def log_generation_params(**kwargs):
logger.info(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str, negative_prompt: Optional[str]):
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
"""Logs the prompt to console."""
if PREFERENCES.prompt:
formatted_prompt = "\n" + prompt
logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
logger.info(
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
)
if negative_prompt:
formatted_negative_prompt = "\n" + negative_prompt
logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
def log_response(response: str):
def log_response(request_id: str, response: str):
"""Logs the response to console."""
if PREFERENCES.prompt:
formatted_response = "\n" + response
logger.info(f"Response: {formatted_response if response else 'Empty'}\n")
logger.info(
f"Response (ID: {request_id}): "
f"{formatted_response if response else 'Empty'}\n"
)
def log_metrics(
request_id: str,
queue_time: float,
prompt_tokens: int,
cached_tokens: int,
@@ -80,7 +86,7 @@ def log_metrics(
max_seq_len: int,
):
initial_response = (
f"Metrics: {generated_tokens} tokens generated in "
f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
)
itemization = []
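The metrics header above is plain f-string assembly; a sketch with illustrative values showing how the request ID lands in the log line:

```python
request_id = "a1b2c3"
generated_tokens = 256
queue_time, prompt_time, generate_time = 0.05, 0.4, 3.1

# Mirrors the initial_response f-string above
initial_response = (
    f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
    f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
)
print(initial_response)
```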


@@ -5,11 +5,14 @@ Containers exist as a common interface for backends.
"""
import pathlib
from enum import Enum
from fastapi import HTTPException
from loguru import logger
from typing import Optional
from common import config
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.utils import unwrap
from endpoints.utils import do_export_openapi
@@ -18,6 +21,21 @@ if not do_export_openapi:
# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None
# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb
if has_infinity_emb:
from backends.infinity.model import InfinityContainer
embeddings_container: Optional[InfinityContainer] = None
class ModelType(Enum):
MODEL = "model"
DRAFT = "draft"
EMBEDDING = "embedding"
def load_progress(module, modules):
@@ -25,11 +43,11 @@ def load_progress(module, modules):
yield module, modules
async def unload_model(skip_wait: bool = False):
async def unload_model(skip_wait: bool = False, shutdown: bool = False):
"""Unloads a model"""
global container
await container.unload(skip_wait=skip_wait)
await container.unload(skip_wait=skip_wait, shutdown=shutdown)
container = None
@@ -46,8 +64,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
f'Model "{loaded_model_name}" is already loaded! Aborting.'
)
# Unload the existing model
if container and container.model:
logger.info("Unloading existing model.")
await unload_model()
@@ -98,17 +114,89 @@ async def unload_loras():
await container.unload(loras_only=True)
def get_config_default(key, fallback=None, is_draft=False):
async def load_embedding_model(model_path: pathlib.Path, **kwargs):
global embeddings_container
# Break out if infinity isn't installed
if not has_infinity_emb:
raise ImportError(
"Skipping embeddings because infinity-emb is not installed.\n"
"Please run the following command in your environment "
"to install extra packages:\n"
"pip install -U .[extras]"
)
# Check if the model is already loaded
if embeddings_container and embeddings_container.engine:
loaded_model_name = embeddings_container.model_dir.name
if loaded_model_name == model_path.name and embeddings_container.model_loaded:
raise ValueError(
f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
)
logger.info("Unloading existing embeddings model.")
await unload_embedding_model()
embeddings_container = InfinityContainer(model_path)
await embeddings_container.load(**kwargs)
async def unload_embedding_model():
global embeddings_container
await embeddings_container.unload()
embeddings_container = None
def get_config_default(key: str, fallback=None, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""
model_config = config.model_config()
default_keys = unwrap(model_config.get("use_as_default"), [])
# Add extra keys to defaults
default_keys.append("embeddings_device")
if key in default_keys:
# Is this a draft model load parameter?
if is_draft:
if model_type == "draft":
draft_config = config.draft_model_config()
return unwrap(draft_config.get(key), fallback)
elif model_type == "embedding":
embeddings_config = config.embeddings_config()
return unwrap(embeddings_config.get(key), fallback)
else:
return unwrap(model_config.get(key), fallback)
else:
return fallback
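The `model_type` dispatch above can be sketched without the surrounding `config` module. The three dicts below stand in for `model_config()`, `draft_model_config()`, and `embeddings_config()`; their contents are illustrative:

```python
def unwrap(value, default=None):
    return value if value is not None else default

MODEL_CONFIG = {"use_as_default": ["max_seq_len"], "max_seq_len": 4096}
DRAFT_CONFIG = {"max_seq_len": 2048}
EMBEDDINGS_CONFIG = {}

def get_config_default(key, fallback=None, model_type="model"):
    """Fetch a default from the relevant config if the user allows it."""
    default_keys = unwrap(MODEL_CONFIG.get("use_as_default"), [])
    # The real code force-appends "embeddings_device" to the defaults
    default_keys = default_keys + ["embeddings_device"]

    if key not in default_keys:
        return fallback
    if model_type == "draft":
        return unwrap(DRAFT_CONFIG.get(key), fallback)
    if model_type == "embedding":
        return unwrap(EMBEDDINGS_CONFIG.get(key), fallback)
    return unwrap(MODEL_CONFIG.get(key), fallback)

print(get_config_default("max_seq_len"))                      # 4096
print(get_config_default("max_seq_len", model_type="draft"))  # 2048
print(get_config_default("chunk_size", fallback=2048))        # 2048
```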
async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""
if container is None or not (container.model_is_loading or container.model_loaded):
error_message = handle_request_error(
"No models are currently loaded.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
async def check_embeddings_container():
"""
FastAPI depends that checks if an embeddings model is loaded.
This is the same as the model container check, but with embeddings instead.
"""
if embeddings_container is None or not (
embeddings_container.model_is_loading or embeddings_container.model_loaded
):
error_message = handle_request_error(
"No embedding models are currently loaded.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)


@@ -1,12 +1,17 @@
"""Common utility functions"""
import asyncio
import json
import socket
import traceback
from fastapi import Request
from fastapi import Depends, HTTPException, Request
from loguru import logger
from pydantic import BaseModel
from typing import Optional
from uuid import uuid4
from common import config
from common.utils import unwrap
class TabbyRequestErrorMessage(BaseModel):
@@ -33,15 +38,18 @@ def get_generator_error(message: str, exc_info: bool = True):
def handle_request_error(message: str, exc_info: bool = True):
"""Log a request error to the console."""
trace = traceback.format_exc()
send_trace = unwrap(config.network_config().get("send_tracebacks"), False)
error_message = TabbyRequestErrorMessage(
message=message, trace=traceback.format_exc()
message=message, trace=trace if send_trace else None
)
request_error = TabbyRequestError(error=error_message)
# Log the error and provided message to the console
if error_message.trace and exc_info:
logger.error(error_message.trace)
if trace and exc_info:
logger.error(trace)
logger.error(f"Sent to request: {message}")
@@ -78,8 +86,9 @@ async def run_with_request_disconnect(
try:
return call_task.result()
except (asyncio.CancelledError, asyncio.InvalidStateError):
except (asyncio.CancelledError, asyncio.InvalidStateError) as ex:
handle_request_disconnect(disconnect_message)
raise HTTPException(422, disconnect_message) from ex
def is_port_in_use(port: int) -> bool:
@@ -93,3 +102,39 @@ def is_port_in_use(port: int) -> bool:
test_socket.settimeout(1)
with test_socket:
return test_socket.connect_ex(("localhost", port)) == 0
async def add_request_id(request: Request):
"""FastAPI depends to add a UUID to a request's state."""
request.state.id = uuid4().hex
return request
async def log_request(request: Request):
"""FastAPI depends to log a request to the user."""
log_message = [f"Information for {request.method} request {request.state.id}:"]
log_message.append(f"URL: {request.url}")
log_message.append(f"Headers: {dict(request.headers)}")
if request.method != "GET":
body_bytes = await request.body()
if body_bytes:
body = json.loads(body_bytes.decode("utf-8"))
log_message.append(f"Body: {dict(body)}")
logger.info("\n".join(log_message))
def get_global_depends():
"""Returns global dependencies for a FastAPI app."""
depends = [Depends(add_request_id)]
if config.logging_config().get("requests"):
depends.append(Depends(log_request))
return depends
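`add_request_id` tags every incoming request with a random 32-character hex string via `uuid4().hex`. A standalone sketch of the ID scheme, using `SimpleNamespace` as a stand-in for FastAPI's request/state objects:

```python
from types import SimpleNamespace
from uuid import uuid4

def add_request_id(request):
    """Attach a UUID hex string to a request-like object's state."""
    request.state.id = uuid4().hex
    return request

request = SimpleNamespace(state=SimpleNamespace())
add_request_id(request)

# uuid4().hex is 32 lowercase hex characters with no dashes
print(len(request.state.id))  # 32
```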


@@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel):
max_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("max_tokens"),
validation_alias=AliasChoices("max_tokens", "max_length"),
description="Aliases: max_length",
examples=[150],
)
min_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
validation_alias=AliasChoices("min_tokens", "min_length"),
description="Aliases: min_length",
examples=[0],
)
@@ -30,13 +34,22 @@ class BaseSamplerRequest(BaseModel):
)
stop: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("stop", [])
default_factory=lambda: get_default_sampler_value("stop", []),
validation_alias=AliasChoices("stop", "stop_sequence"),
description="Aliases: stop_sequence",
)
banned_strings: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("banned_strings", [])
)
banned_tokens: Optional[Union[List[int], str]] = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",
examples=[[128, 330]],
)
token_healing: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
@@ -76,6 +89,13 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
)
skew: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("skew", 0.0),
examples=[0.0],
@@ -91,9 +111,24 @@ class BaseSamplerRequest(BaseModel):
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
description="Aliases: rep_pen",
examples=[1.0],
)
penalty_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
"rep_pen_range",
),
description=(
"Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
),
)
repetition_decay: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
@@ -118,6 +153,8 @@ class BaseSamplerRequest(BaseModel):
ban_eos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
description="Aliases: ignore_eos",
examples=[False],
)
@@ -151,24 +188,6 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("speculative_ngram"),
)
# Aliased variables
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
)
penalty_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
),
description="Aliases: repetition_range, repetition_penalty_range",
)
cfg_scale: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
@@ -196,13 +215,6 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
banned_tokens: Optional[Union[List[int], str]] = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",
examples=[[128, 330]],
)
# TODO: Return to adaptable class-based validation, but that's just too much
# abstraction compared to simple if statements at the moment
def validate_params(self):


@@ -1,13 +1,40 @@
import asyncio
import signal
import sys
from loguru import logger
from types import FrameType
from common import model
SHUTTING_DOWN: bool = False
def signal_handler(*_):
"""Signal handler for main function. Run before uvicorn starts."""
global SHUTTING_DOWN
if SHUTTING_DOWN:
return
logger.warning("Shutdown signal called. Exiting gracefully.")
SHUTTING_DOWN = True
# Run async unloads for model
asyncio.ensure_future(signal_handler_async())
async def signal_handler_async(*_):
"""Internal signal handler. Runs all async code to shut down the program."""
if model.container:
await model.unload_model(skip_wait=True, shutdown=True)
if model.embeddings_container:
await model.unload_embedding_model()
# Exit the program
sys.exit(0)
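The handler above uses a module-level flag so repeated Ctrl+C presses don't schedule the async unloads twice. The re-entrancy guard alone can be sketched synchronously (the `calls` list stands in for scheduling `signal_handler_async`):

```python
SHUTTING_DOWN = False
calls = []

def signal_handler(*_):
    """Ignore repeated shutdown signals; only act on the first one."""
    global SHUTTING_DOWN
    if SHUTTING_DOWN:
        return
    SHUTTING_DOWN = True
    # The real handler calls asyncio.ensure_future(signal_handler_async()) here
    calls.append("shutdown")

signal_handler()
signal_handler()  # second signal is a no-op
print(calls)  # ['shutdown']
```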


@@ -51,7 +51,7 @@ class PromptTemplate:
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"Please update jinja by running the following command: "
"pip install --upgrade jinja2"
)


@@ -1,6 +1,7 @@
import json
import pathlib
from typing import List, Optional, Union
from loguru import logger
from pydantic import BaseModel
@@ -11,6 +12,7 @@ class GenerationConfig(BaseModel):
"""
eos_token_id: Optional[Union[int, List[int]]] = None
bad_words_ids: Optional[List[List[int]]] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
@@ -30,3 +32,38 @@ class GenerationConfig(BaseModel):
return [self.eos_token_id]
else:
return self.eos_token_id
class HuggingFaceConfig(BaseModel):
"""
An abridged version of HuggingFace's model config.
Will be expanded as needed.
"""
badwordsids: Optional[str] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
hf_config_dict = json.load(hf_config_json)
return self.model_validate(hf_config_dict)
def get_badwordsids(self):
"""Wrapper method to fetch badwordsids."""
if self.badwordsids:
try:
bad_words_list = json.loads(self.badwordsids)
return bad_words_list
except json.JSONDecodeError:
logger.warning(
"Skipping badwordsids from config.json "
"since it's not a valid array."
)
return []
else:
return []
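The `get_badwordsids` wrapper above boils down to a guarded `json.loads` with an empty-list fallback. A minimal standalone version of the same parsing:

```python
import json

def parse_badwordsids(badwordsids):
    """Parse a JSON-encoded bad-words list, falling back to [] on bad input."""
    if not badwordsids:
        return []
    try:
        return json.loads(badwordsids)
    except json.JSONDecodeError:
        # The real code logs a warning before skipping the field
        return []

print(parse_badwordsids("[[128], [330]]"))  # [[128], [330]]
print(parse_badwordsids("not json"))        # []
print(parse_badwordsids(None))              # []
```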


@@ -18,3 +18,9 @@ def prune_dict(input_dict):
"""Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None}
def flat_map(input_list):
"""Flattens a list of lists into a single list."""
return [item for sublist in input_list for item in sublist]
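The new `flat_map` helper is a standard nested list comprehension; copied here with a usage example (the Kobold router uses it to flatten `bad_words_ids` into `banned_tokens`):

```python
def flat_map(input_list):
    """Flattens a list of lists into a single list."""
    return [item for sublist in input_list for item in sublist]

print(flat_map([[128], [330, 331]]))  # [128, 330, 331]
```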


@@ -19,6 +19,14 @@ network:
# Turn on this option if you are ONLY connecting from localhost
disable_auth: False
# Send tracebacks over the API to clients (default: False)
# NOTE: Only enable this for debug purposes
send_tracebacks: False
# Select API servers to enable (default: ["OAI"])
# Possible values: OAI
api_servers: ["OAI"]
# Options for logging
logging:
# Enable prompt logging (default: False)
@@ -27,6 +35,10 @@ logging:
# Enable generation parameter logging (default: False)
generation_params: False
# Enable request logging (default: False)
# NOTE: Only use this for debugging!
requests: False
# Options for sampling
sampling:
# Override preset name. Find this in the sampler-overrides folder (default: None)
@@ -50,6 +62,16 @@ developer:
# This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk.
#cuda_malloc_backend: False
# Enable Uvloop or Winloop (default: False)
# Make the program utilize a faster async event loop which can improve performance
# NOTE: It's recommended to enable this, but if something breaks, turn this off.
#uvloop: False
# Set process to use a higher priority
# For realtime process priority, run as administrator or sudo
# Otherwise, the priority will be set to high
#realtime_process_priority: False
# Options for model overrides and loading
# Please read the comments to understand how arguments are handled between initial and API loads
model:
@@ -124,11 +146,11 @@ model:
# NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096
#chunk_size: 2048
# Set the maximum amount of prompts to process at one time (batch)
# This will be automatically adjusted depending on the cache size.
# Set the maximum amount of prompts to process at one time (default: None/Automatic)
# This will be automatically calculated if left blank.
# A max batch size of 1 processes prompts one at a time.
# NOTE: Only available for Nvidia ampere (30 series) and above GPUs
#max_batch_size: 20
#max_batch_size:
# Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
# If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
@@ -179,3 +201,22 @@ model:
#loras:
#- name: lora1
# scaling: 1.0
# Options for embedding models and loading.
# NOTE: Embeddings requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
# Overrides directory to look for embedding models (default: models)
embedding_model_dir: models
# Device to load embedding models on (default: cpu)
# Possible values: cpu, auto, cuda
# NOTE: It's recommended to load embedding models on the CPU.
# If you'd like to load on an AMD GPU, set this value to "cuda" as well.
embeddings_device: cpu
# The below parameters only apply for initial loads
# All API based loads do NOT inherit these settings unless specified in use_as_default
# An initial embedding model to load on the infinity backend (default: None)
embedding_model_name:


@@ -12,7 +12,7 @@ services:
- NAME=TabbyAPI
- NVIDIA_VISIBLE_DEVICES=all
volumes:
- ./models:/usr/src/app/models
- ./models:/app/models
deploy:
resources:
reservations:

endpoints/Kobold/router.py (new file, 161 lines)

@@ -0,0 +1,161 @@
from sys import maxsize
from fastapi import APIRouter, Depends, Request
from sse_starlette import EventSourceResponse
from common import model
from common.auth import check_api_key
from common.model import check_model_container
from common.utils import unwrap
from endpoints.core.utils.model import get_current_model
from endpoints.Kobold.types.generation import (
AbortRequest,
AbortResponse,
CheckGenerateRequest,
GenerateRequest,
GenerateResponse,
)
from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
from endpoints.Kobold.utils.generation import (
abort_generation,
generation_status,
get_generation,
stream_generation,
)
api_name = "KoboldAI"
router = APIRouter(prefix="/api")
urls = {
"Generation": "http://{host}:{port}/api/v1/generate",
"Streaming": "http://{host}:{port}/api/extra/generate/stream",
}
kai_router = APIRouter()
extra_kai_router = APIRouter()
def setup():
router.include_router(kai_router, prefix="/v1")
router.include_router(kai_router, prefix="/latest", include_in_schema=False)
router.include_router(extra_kai_router, prefix="/extra")
return router
@kai_router.post(
"/generate",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
response = await get_generation(data, request)
return response
@extra_kai_router.post(
"/generate/stream",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
response = EventSourceResponse(stream_generation(data, request), ping=maxsize)
return response
@extra_kai_router.post(
"/abort",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest) -> AbortResponse:
response = await abort_generation(data.genkey)
return response
@extra_kai_router.get(
"/generate/check",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.post(
"/generate/check",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
response = await generation_status(data.genkey)
return response
@kai_router.get(
"/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model() -> CurrentModelResponse:
"""Fetches the current model and who owns it."""
current_model_card = get_current_model()
return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}
@extra_kai_router.post(
"/tokencount",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
raw_tokens = model.container.encode_tokens(data.prompt)
tokens = unwrap(raw_tokens, [])
return TokenCountResponse(value=len(tokens), ids=tokens)
@kai_router.get(
"/config/max_length",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@kai_router.get(
"/config/max_context_length",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.get(
"/true_max_context_length",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_max_length() -> MaxLengthResponse:
"""Fetches the max length of the model."""
max_length = model.container.get_model_parameters().get("max_seq_len")
return {"value": max_length}
@kai_router.get("/info/version")
async def get_version():
"""Impersonate KAI United."""
return {"result": "1.2.5"}
@extra_kai_router.get("/version")
async def get_extra_version():
"""Impersonate Koboldcpp."""
return {"result": "KoboldCpp", "version": "1.61"}
@kai_router.get("/config/soft_prompts_list")
async def get_available_softprompts():
"""Used for KAI compliance."""
return {"values": []}
@kai_router.get("/config/soft_prompt")
async def get_current_softprompt():
"""Used for KAI compliance."""
return {"value": ""}
@kai_router.put("/config/soft_prompt")
async def set_current_softprompt():
"""Used for KAI compliance."""
return {}


@@ -0,0 +1,60 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from common import model
from common.sampling import BaseSamplerRequest, get_default_sampler_value
from common.utils import flat_map, unwrap
class GenerateRequest(BaseSamplerRequest):
prompt: str
genkey: Optional[str] = None
use_default_badwordsids: Optional[bool] = False
dynatemp_range: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("dynatemp_range")
)
def to_gen_params(self, **kwargs):
# Exl2 uses -1 to include all tokens in repetition penalty
if self.penalty_range == 0:
self.penalty_range = -1
if self.dynatemp_range:
self.min_temp = self.temperature - self.dynatemp_range
self.max_temp = self.temperature + self.dynatemp_range
# Move badwordsids into banned tokens for generation
if self.use_default_badwordsids:
bad_words_ids = unwrap(
model.container.generation_config.bad_words_ids,
model.container.hf_config.get_badwordsids(),
)
if bad_words_ids:
self.banned_tokens += flat_map(bad_words_ids)
return super().to_gen_params(**kwargs)
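`to_gen_params` maps Kobold's `dynatemp_range` onto TabbyAPI's min/max temperature pair by widening around the base temperature, and remaps `penalty_range == 0` to exllamav2's -1 sentinel. A sketch of that arithmetic with a stripped-down stand-in for the Pydantic model:

```python
class GenerateRequest:
    """Stripped-down stand-in for the Pydantic request model above."""

    def __init__(self, temperature, dynatemp_range=None, penalty_range=0):
        self.temperature = temperature
        self.dynatemp_range = dynatemp_range
        self.penalty_range = penalty_range
        self.min_temp = None
        self.max_temp = None

    def to_gen_params(self):
        # Exl2 uses -1 to include all tokens in repetition penalty
        if self.penalty_range == 0:
            self.penalty_range = -1
        if self.dynatemp_range:
            self.min_temp = self.temperature - self.dynatemp_range
            self.max_temp = self.temperature + self.dynatemp_range
        return {
            "min_temp": self.min_temp,
            "max_temp": self.max_temp,
            "penalty_range": self.penalty_range,
        }

params = GenerateRequest(temperature=1.0, dynatemp_range=0.25).to_gen_params()
print(params["min_temp"], params["max_temp"])  # 0.75 1.25
```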
class GenerateResponseResult(BaseModel):
text: str
class GenerateResponse(BaseModel):
results: List[GenerateResponseResult] = Field(default_factory=list)
class StreamGenerateChunk(BaseModel):
token: str
class AbortRequest(BaseModel):
genkey: str
class AbortResponse(BaseModel):
success: bool
class CheckGenerateRequest(BaseModel):
genkey: str


@@ -0,0 +1,9 @@
from pydantic import BaseModel
class CurrentModelResponse(BaseModel):
result: str
class MaxLengthResponse(BaseModel):
value: int


@@ -0,0 +1,15 @@
from pydantic import BaseModel
from typing import List
class TokenCountRequest(BaseModel):
"""Represents a KAI tokenization request."""
prompt: str
class TokenCountResponse(BaseModel):
"""Represents a KAI tokenization response."""
value: int
ids: List[int]


@@ -0,0 +1,151 @@
import asyncio
from asyncio import CancelledError
from fastapi import HTTPException, Request
from loguru import logger
from sse_starlette import ServerSentEvent
from common import model
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
AbortResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseResult,
StreamGenerateChunk,
)
generation_cache = {}
async def override_request_id(request: Request, data: GenerateRequest):
"""Overrides the request ID with a KAI genkey if present."""
if data.genkey:
request.state.id = data.genkey
def _create_response(text: str):
results = [GenerateResponseResult(text=text)]
return GenerateResponse(results=results)
def _create_stream_chunk(text: str):
return StreamGenerateChunk(token=text)
async def _stream_collector(data: GenerateRequest, request: Request):
"""Common async generator for generation streams."""
abort_event = asyncio.Event()
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
# Create a new entry in the cache
generation_cache[data.genkey] = {"abort": abort_event, "text": ""}
try:
logger.info(f"Received Kobold generation request {data.genkey}")
generator = model.container.generate_gen(
data.prompt, data.genkey, abort_event, **data.to_gen_params()
)
async for generation in generator:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect(
f"Kobold generation {data.genkey} cancelled by user."
)
text = generation.get("text")
# Update the generation cache with the new chunk
if text:
generation_cache[data.genkey]["text"] += text
yield text
if "finish_reason" in generation:
logger.info(f"Finished streaming Kobold request {data.genkey}")
break
except CancelledError:
# If the request disconnects, break out
if not disconnect_task.done():
abort_event.set()
handle_request_disconnect(
f"Kobold generation {data.genkey} cancelled by user."
)
finally:
# Cleanup the cache
del generation_cache[data.genkey]
async def stream_generation(data: GenerateRequest, request: Request):
"""Wrapper for stream generations."""
# If the genkey doesn't exist, set it to the request ID
if not data.genkey:
data.genkey = request.state.id
try:
async for chunk in _stream_collector(data, request):
response = _create_stream_chunk(chunk)
yield ServerSentEvent(
event="message", data=response.model_dump_json(), sep="\n"
)
except Exception:
yield get_generator_error(
f"Kobold generation {data.genkey} aborted. "
"Please check the server console."
)
async def get_generation(data: GenerateRequest, request: Request):
"""Wrapper to get a static generation."""
# If the genkey doesn't exist, set it to the request ID
if not data.genkey:
data.genkey = request.state.id
try:
full_text = ""
async for chunk in _stream_collector(data, request):
full_text += chunk
response = _create_response(full_text)
return response
except Exception as exc:
error_message = handle_request_error(
f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
"Please check the server console."
).error.message
# Server error if there's a generation exception
raise HTTPException(503, error_message) from exc
async def abort_generation(genkey: str):
"""Aborts a generation from the cache."""
abort_event = unwrap(generation_cache.get(genkey), {}).get("abort")
if abort_event:
abort_event.set()
handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")
return AbortResponse(success=True)
async def generation_status(genkey: str):
"""Fetches the status of a generation from the cache."""
current_text = unwrap(generation_cache.get(genkey), {}).get("text")
if current_text:
response = _create_response(current_text)
else:
response = GenerateResponse()
return response
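The abort and status endpoints share the module-level `generation_cache` keyed by genkey, each entry holding an `asyncio.Event` for aborting and the text streamed so far. A minimal sketch of that cache contract (the synchronous wrappers stand in for the async route handlers):

```python
import asyncio

generation_cache = {}

def start_generation(genkey):
    """Register a generation with an abort event and empty text buffer."""
    generation_cache[genkey] = {"abort": asyncio.Event(), "text": ""}

def abort_generation(genkey):
    """Set the abort event if the generation exists; mirrors AbortResponse."""
    abort_event = generation_cache.get(genkey, {}).get("abort")
    if abort_event:
        abort_event.set()
    return {"success": True}

def generation_status(genkey):
    """Return the text streamed so far, or "" for unknown genkeys."""
    return generation_cache.get(genkey, {}).get("text", "")

start_generation("KCPP1234")
generation_cache["KCPP1234"]["text"] += "Hello"
print(generation_status("KCPP1234"))  # Hello
abort_generation("KCPP1234")
print(generation_cache["KCPP1234"]["abort"].is_set())  # True
```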


@@ -1,48 +1,21 @@
import asyncio
import pathlib
from loguru import logger
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.downloader import hf_repo_download
from common import config, model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
LoraCard,
LoraList,
LoraLoadRequest,
LoraLoadResponse,
)
from endpoints.OAI.types.model import (
ModelCard,
ModelList,
ModelLoadRequest,
ModelCardParameters,
ModelLoadResponse,
)
from endpoints.OAI.types.sampler_overrides import (
SamplerOverrideListResponse,
SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
TokenDecodeRequest,
TokenDecodeResponse,
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
@@ -52,25 +25,19 @@ from endpoints.OAI.utils.completion import (
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list
from endpoints.OAI.utils.embeddings import get_embeddings
api_name = "OAI"
router = APIRouter()
urls = {
"Completions": "http://{host}:{port}/v1/completions",
"Chat completions": "http://{host}:{port}/v1/chat/completions",
}
async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""
if model.container is None or not (
model.container.model_is_loading or model.container.model_loaded
):
error_message = handle_request_error(
"No models are currently loaded.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
def setup():
return router
# Completions endpoint
@@ -106,12 +73,14 @@ async def completion_request(
ping=maxsize,
)
else:
generate_task = asyncio.create_task(generate_completion(data, model_path))
generate_task = asyncio.create_task(
generate_completion(data, request, model_path)
)
response = await run_with_request_disconnect(
request,
generate_task,
disconnect_message="Completion generation cancelled by user.",
disconnect_message=f"Completion {request.state.id} cancelled by user.",
)
return response
@@ -206,392 +175,28 @@ async def chat_completion_request(
)
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, model_path)
generate_chat_completion(prompt, data, request, model_path)
)
response = await run_with_request_disconnect(
request,
generate_task,
disconnect_message="Chat completion generation cancelled by user.",
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
)
return response
# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models() -> ModelList:
"""Lists all models in the model directory."""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model_config().get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
# Currently loaded model endpoint
@router.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_current_model() -> ModelCard:
"""Returns the currently loaded model."""
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models() -> ModelList:
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
return models
# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""
# Verify request parameters
if not data.name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = model_path / data.name
draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
if not model_path.exists():
error_message = handle_request_error(
"Could not find the model path for load. Check model name or config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
return EventSourceResponse(
stream_model_load(data, model_path, draft_model_path), ping=maxsize
)
# Unload model endpoint
# Embeddings endpoint
@router.post(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
"/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def unload_model():
"""Unloads the currently loaded model."""
await model.unload_model(skip_wait=True)
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
"""Downloads a model from HuggingFace."""
try:
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
return DownloadResponse(download_path=str(download_path))
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras() -> LoraList:
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
return loras
# Currently loaded loras endpoint
@router.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_active_loras() -> LoraList:
"""Returns the currently loaded loras."""
active_loras = LoraList(
data=[
LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
)
for lora in model.container.get_loras()
]
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
embeddings_task = asyncio.create_task(get_embeddings(data, request))
response = await run_with_request_disconnect(
request,
embeddings_task,
f"Embeddings request {request.state.id} cancelled by user.",
)
return active_loras
# Load lora endpoint
@router.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
"""Loads a LoRA into the model container."""
if not data.loras:
error_message = handle_request_error(
"List of loras to load is not found.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
load_result = await model.load_loras(
lora_dir, **data.model_dump(), skip_wait=data.skip_queue
)
return LoraLoadResponse(
success=unwrap(load_result.get("success"), []),
failure=unwrap(load_result.get("failure"), []),
)
# Unload lora endpoint
@router.post(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
"""Unloads the currently loaded loras."""
await model.unload_loras()
# Encode tokens endpoint
@router.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
if isinstance(data.text, str):
text = data.text
else:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
template_vars = {
"messages": data.text,
"add_generation_prompt": False,
**special_tokens_dict,
}
text, _ = model.container.prompt_template.render(template_vars)
raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
return response
# Decode tokens endpoint
@router.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
"""Decodes tokens into a string."""
message = model.container.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))
return response
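The two token endpoints above are symmetric: encode returns token ids plus a `length` field computed from the list, decode returns the text. An illustrative encode payload (the token ids are hypothetical, not from a real tokenizer):

```python
# Hypothetical token ids for illustration only.
tokens = [1, 529, 7451]

# Shape of TokenEncodeResponse: the length field is always len(tokens).
encode_response = {"tokens": tokens, "length": len(tokens)}
```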
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None),
) -> AuthPermissionResponse:
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-admin-key
- X-api-key
- Authorization
"""
test_key = coalesce(x_admin_key, x_api_key, authorization)
try:
permission = await validate_key_permission(test_key)
return AuthPermissionResponse(permission=permission)
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
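`coalesce` here picks the first header that was actually sent. A stdlib sketch of that helper (the real implementation lives in `common.utils` and may differ in detail):

```python
def coalesce(*args):
    """Return the first argument that is not None, else None
    (sketch of the helper used by the permission endpoint above)."""
    return next((arg for arg in args if arg is not None), None)


# Header precedence as in the endpoint: the admin key wins when present.
key = coalesce(None, "user-key", "Bearer abc")
```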
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates() -> TemplateList:
templates = get_all_templates()
template_strings = [template.stem for template in templates]
return TemplateList(data=template_strings)
@router.post(
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template"""
if not data.name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
try:
model.container.prompt_template = PromptTemplate.from_file(data.name)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
exc_info=False,
).error.message
raise HTTPException(400, error_message) from e
@router.post(
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
"""Unloads the currently selected template"""
model.container.prompt_template = None
# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides() -> SamplerOverrideListResponse:
"""API wrapper to list all currently applied sampler overrides"""
return SamplerOverrideListResponse(
presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
)
@router.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""
if data.preset:
try:
sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
error_message = handle_request_error(
f"Sampler override preset with name {data.preset} does not exist. "
+ "Check the spelling?",
exc_info=False,
).error.message
raise HTTPException(400, error_message) from e
elif data.overrides:
sampling.overrides_from_dict(data.overrides)
else:
error_message = handle_request_error(
"A sampler override preset or dictionary wasn't provided.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
@router.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
sampling.overrides_from_dict({})


@@ -65,3 +65,4 @@ class ChatCompletionStreamChunk(BaseModel):
created: int = Field(default_factory=lambda: int(time()))
model: str
object: str = "chat.completion.chunk"
usage: Optional[UsageStats] = None


@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
type: str = "text"
class ChatCompletionStreamOptions(BaseModel):
include_usage: Optional[bool] = False
class CommonCompletionRequest(BaseSamplerRequest):
"""Represents a common completion request."""
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
stream_options: Optional[ChatCompletionStreamOptions] = None
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0)
)


@@ -0,0 +1,42 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class EmbeddingsRequest(BaseModel):
input: List[str] = Field(
..., description="List of input texts to generate embeddings for."
)
encoding_format: str = Field(
"float",
description="Encoding format for the embeddings. "
"Can be 'float' or 'base64'.",
)
model: Optional[str] = Field(
None,
description="Name of the embedding model to use. "
"If not provided, the default model will be used.",
)
class EmbeddingObject(BaseModel):
object: str = Field("embedding", description="Type of the object.")
embedding: List[float] = Field(
..., description="Embedding values as a list of floats."
)
index: int = Field(
..., description="Index of the input text corresponding to " "the embedding."
)
class EmbeddingsResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
model: str = Field(..., description="Name of the embedding model used.")
usage: UsageInfo = Field(..., description="Information about token usage.")
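The models above serialize to the OpenAI-style embeddings response shape. A plain-dict illustration (all values are hypothetical, not from a real model):

```python
# Hypothetical payload matching the EmbeddingsResponse schema above.
response = {
    "object": "list",
    "data": [
        # One EmbeddingObject per input text, indexed in input order.
        {"object": "embedding", "embedding": [0.12, -0.03, 0.88], "index": 0},
    ],
    "model": "example-embedding-model",
    "usage": {"prompt_tokens": 5, "total_tokens": 5, "completion_tokens": 0},
}
```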


@@ -5,7 +5,6 @@ import pathlib
from asyncio import CancelledError
from copy import deepcopy
from typing import List, Optional
from uuid import uuid4
from fastapi import HTTPException, Request
from jinja2 import TemplateError
@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
ChatCompletionStreamChoice,
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
def _create_response(generations: List[dict], model_name: Optional[str]):
def _create_response(
request_id: str, generations: List[dict], model_name: Optional[str]
):
"""Create a chat completion response from the provided text."""
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
choices.append(choice)
response = ChatCompletionResponse(
id=f"chatcmpl-{request_id}",
choices=choices,
model=unwrap(model_name, ""),
usage=UsageStats(
@@ -90,25 +93,40 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
def _create_stream_chunk(
const_id: str,
request_id: str,
generation: Optional[dict] = None,
model_name: Optional[str] = None,
is_usage_chunk: bool = False,
):
"""Create a chat completion stream chunk from the provided text."""
index = generation.get("index")
logprob_response = None
choices = []
usage_stats = None
if "finish_reason" in generation:
if is_usage_chunk:
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
usage_stats = UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
elif "finish_reason" in generation:
choice = ChatCompletionStreamChoice(
index=index,
finish_reason=generation.get("finish_reason"),
)
choices.append(choice)
else:
message = ChatCompletionMessage(
role="assistant", content=unwrap(generation.get("text"), "")
)
logprob_response = None
token_probs = unwrap(generation.get("token_probs"), {})
if token_probs:
logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +150,13 @@ def _create_stream_chunk(
logprobs=logprob_response,
)
choices.append(choice)
chunk = ChatCompletionStreamChunk(
id=const_id, choices=[choice], model=unwrap(model_name, "")
id=f"chatcmpl-{request_id}",
choices=choices,
model=unwrap(model_name, ""),
usage=usage_stats,
)
return chunk
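The usage-chunk branch above reads the final generation's counts, defaults missing values to 0 via `unwrap`, and sums them for the total. A minimal standalone sketch of that arithmetic (`unwrap` is re-implemented here with the same None-fallback contract as `common.utils.unwrap`):

```python
def unwrap(value, default):
    # Same contract as common.utils.unwrap: fall back only on None.
    return default if value is None else value


def usage_from_generation(generation: dict) -> dict:
    """Mirror of the usage-chunk math: missing counts default to 0,
    and total_tokens is the sum of prompt and completion tokens."""
    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
```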
@@ -215,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
raise HTTPException(400, error_message) from exc
async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
abort_event: asyncio.Event,
**kwargs,
):
"""Collects a stream and places results in a common queue"""
try:
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
await gen_queue.put(generation)
if "finish_reason" in generation:
break
except Exception as e:
await gen_queue.put(e)
async def stream_generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
"""Generator for the generation process."""
const_id = f"chatcmpl-{uuid4().hex}"
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = []
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
logger.info(f"Received chat completion streaming request {request.state.id}")
gen_params = data.to_gen_params()
for n in range(0, data.n):
@@ -257,7 +259,14 @@ async def stream_generate_chat_completion(
task_gen_params = gen_params
gen_task = asyncio.create_task(
_stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
_stream_collector(
n,
gen_queue,
prompt,
request.state.id,
abort_event,
**task_gen_params,
)
)
gen_tasks.append(gen_task)
@@ -266,7 +275,9 @@ async def stream_generate_chat_completion(
while True:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
handle_request_disconnect(
f"Chat completion generation {request.state.id} cancelled by user."
)
generation = await gen_queue.get()
@@ -274,17 +285,35 @@ async def stream_generate_chat_completion(
if isinstance(generation, Exception):
raise generation
response = _create_stream_chunk(const_id, generation, model_path.name)
response = _create_stream_chunk(
request.state.id, generation, model_path.name
)
yield response.model_dump_json()
# Check if all tasks are completed
if all(task.done() for task in gen_tasks) and gen_queue.empty():
# Send a usage chunk
if data.stream_options and data.stream_options.include_usage:
usage_chunk = _create_stream_chunk(
request.state.id,
generation,
model_path.name,
is_usage_chunk=True,
)
yield usage_chunk.model_dump_json()
logger.info(
f"Finished chat completion streaming request {request.state.id}"
)
yield "[DONE]"
break
except CancelledError:
# Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Chat completion generation cancelled by user.")
if not disconnect_task.done():
abort_event.set()
handle_request_disconnect("Chat completion generation cancelled by user.")
except Exception:
yield get_generator_error(
"Chat completion aborted. Please check the server console."
@@ -292,7 +321,7 @@ async def stream_generate_chat_completion(
async def generate_chat_completion(
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
gen_tasks: List[asyncio.Task] = []
gen_params = data.to_gen_params()
@@ -307,16 +336,23 @@ async def generate_chat_completion(
task_gen_params = gen_params
gen_tasks.append(
asyncio.create_task(model.container.generate(prompt, **task_gen_params))
asyncio.create_task(
model.container.generate(
prompt, request.state.id, **task_gen_params
)
)
)
generations = await asyncio.gather(*gen_tasks)
response = _create_response(generations, model_path.name)
response = _create_response(request.state.id, generations, model_path.name)
logger.info(f"Finished chat completion request {request.state.id}")
return response
except Exception as exc:
error_message = handle_request_error(
"Chat completion aborted. Maybe the model was unloaded? "
f"Chat completion {request.state.id} aborted. "
"Maybe the model was unloaded? "
"Please check the server console."
).error.message


@@ -7,6 +7,8 @@ from copy import deepcopy
from fastapi import HTTPException, Request
from typing import List, Union
from loguru import logger
from common import model
from common.networking import (
get_generator_error,
@@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
from endpoints.OAI.types.common import UsageStats
def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
def _create_response(
request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
):
"""Create a completion response from the provided choices."""
# Convert the single choice object into a list
@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
response = CompletionResponse(
id=f"cmpl-{request_id}",
choices=choices,
model=model_name,
usage=UsageStats(
@@ -77,13 +82,16 @@ async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
request_id: str,
abort_event: asyncio.Event,
**kwargs,
):
"""Collects a stream and places results in a common queue"""
try:
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
new_generation = model.container.generate_gen(
prompt, request_id, abort_event, **kwargs
)
async for generation in new_generation:
generation["index"] = task_idx
@@ -106,6 +114,8 @@ async def stream_generate_completion(
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
logger.info(f"Received streaming completion request {request.state.id}")
gen_params = data.to_gen_params()
for n in range(0, data.n):
@@ -116,7 +126,12 @@ async def stream_generate_completion(
gen_task = asyncio.create_task(
_stream_collector(
n, gen_queue, data.prompt, abort_event, **task_gen_params
n,
gen_queue,
data.prompt,
request.state.id,
abort_event,
**task_gen_params,
)
)
@@ -126,7 +141,9 @@ async def stream_generate_completion(
while True:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
handle_request_disconnect(
f"Completion generation {request.state.id} cancelled by user."
)
generation = await gen_queue.get()
@@ -134,31 +151,39 @@ async def stream_generate_completion(
if isinstance(generation, Exception):
raise generation
response = _create_response(generation, model_path.name)
response = _create_response(request.state.id, generation, model_path.name)
yield response.model_dump_json()
# Check if all tasks are completed
if all(task.done() for task in gen_tasks) and gen_queue.empty():
yield "[DONE]"
logger.info(f"Finished streaming completion request {request.state.id}")
break
except CancelledError:
# Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
if not disconnect_task.done():
abort_event.set()
handle_request_disconnect(
f"Completion generation {request.state.id} cancelled by user."
)
except Exception:
yield get_generator_error(
"Completion aborted. Please check the server console."
f"Completion {request.state.id} aborted. Please check the server console."
)
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
async def generate_completion(
data: CompletionRequest, request: Request, model_path: pathlib.Path
):
"""Non-streaming generate for completions"""
gen_tasks: List[asyncio.Task] = []
gen_params = data.to_gen_params()
try:
logger.info(f"Recieved completion request {request.state.id}")
for n in range(0, data.n):
# Deepcopy gen params above the first index
# to ensure nested structures aren't shared
@@ -169,17 +194,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
gen_tasks.append(
asyncio.create_task(
model.container.generate(data.prompt, **task_gen_params)
model.container.generate(
data.prompt, request.state.id, **task_gen_params
)
)
)
generations = await asyncio.gather(*gen_tasks)
response = _create_response(generations, model_path.name)
response = _create_response(request.state.id, generations, model_path.name)
logger.info(f"Finished completion request {request.state.id}")
return response
except Exception as exc:
error_message = handle_request_error(
"Completion aborted. Maybe the model was unloaded? "
f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
"Please check the server console."
).error.message


@@ -0,0 +1,64 @@
"""
This file is derived from
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
and modified.
The changes introduced are: suppression of the progress bar,
typing/pydantic classes moved into this file,
and the embeddings function declared async.
"""
import base64
from fastapi import Request
import numpy as np
from loguru import logger
from common import model
from endpoints.OAI.types.embedding import (
EmbeddingObject,
EmbeddingsRequest,
EmbeddingsResponse,
UsageInfo,
)
def float_list_to_base64(float_array: np.ndarray) -> str:
"""
Converts the provided list to a float32 array for OpenAI
Ex. float_array = np.array(float_list, dtype="float32")
"""
# Encode raw bytes into base64
encoded_bytes = base64.b64encode(float_array.tobytes())
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode("ascii")
return ascii_string
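The helper above base64-encodes the raw float32 bytes of a numpy array. A numpy-free sketch of the same round-trip using the stdlib `array` module (an assumption for illustration; the real code relies on numpy's `tobytes`):

```python
import base64
from array import array


def floats_to_base64(values):
    """Pack floats as machine-native float32 and base64-encode them,
    mirroring the numpy-based helper above (stdlib-only sketch)."""
    buf = array("f", values)  # "f" = 32-bit float
    return base64.b64encode(buf.tobytes()).decode("ascii")


def base64_to_floats(encoded):
    """Inverse: decode base64 back into a list of float32 values."""
    buf = array("f")
    buf.frombytes(base64.b64decode(encoded))
    return list(buf)
```

Values pass through a 32-bit representation, so exactly representable floats such as 0.5 survive the round-trip unchanged.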
async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
model_path = model.embeddings_container.model_dir
logger.info(f"Recieved embeddings request {request.state.id}")
embedding_data = await model.embeddings_container.generate(data.input)
# OAI expects a return of base64 if the input is base64
embedding_object = [
EmbeddingObject(
embedding=float_list_to_base64(emb)
if data.encoding_format == "base64"
else emb.tolist(),
index=n,
)
for n, emb in enumerate(embedding_data.get("embeddings"))
]
usage = embedding_data.get("usage")
response = EmbeddingsResponse(
data=embedding_object,
model=model_path.name,
usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
)
logger.info(f"Finished embeddings request {request.state.id}")
return response


@@ -1,14 +0,0 @@
import pathlib
from endpoints.OAI.types.lora import LoraCard, LoraList
def get_lora_list(lora_path: pathlib.Path):
"""Get the list of Lora cards from the provided path."""
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id=path.name)
lora_list.data.append(lora_card)
return lora_list

endpoints/core/router.py Normal file

@@ -0,0 +1,521 @@
import asyncio
import pathlib
from sys import maxsize
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from common import config, model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
from endpoints.core.types.model import (
EmbeddingModelLoadRequest,
ModelCard,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
)
from endpoints.core.types.sampler_overrides import (
SamplerOverrideListResponse,
SamplerOverrideSwitchRequest,
)
from endpoints.core.types.template import TemplateList, TemplateSwitchRequest
from endpoints.core.types.token import (
TokenDecodeRequest,
TokenDecodeResponse,
TokenEncodeRequest,
TokenEncodeResponse,
)
from endpoints.core.utils.lora import get_active_loras, get_lora_list
from endpoints.core.utils.model import (
get_current_model,
get_current_model_list,
get_model_list,
stream_model_load,
)
router = APIRouter()
# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models(request: Request) -> ModelList:
"""
Lists all models in the model directory.
Requires an admin key to see all models.
"""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model_config().get("draft_model_dir")
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
# Currently loaded model endpoint
@router.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def current_model() -> ModelCard:
"""Returns the currently loaded model."""
return get_current_model()
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models(request: Request) -> ModelList:
"""
Lists all draft models in the model directory.
Requires an admin key to see all draft models.
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
else:
models = await get_current_model_list(is_draft=True)
return models
# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""
# Verify request parameters
if not data.name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
model_path = model_path / data.name
draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
draft_model_path = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
if not model_path.exists():
error_message = handle_request_error(
"Could not find the model path for load. Check model name or config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
return EventSourceResponse(
stream_model_load(data, model_path, draft_model_path), ping=maxsize
)
# Unload model endpoint
@router.post(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_model():
"""Unloads the currently loaded model."""
await model.unload_model(skip_wait=True)
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
"""Downloads a model from HuggingFace."""
try:
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
return DownloadResponse(download_path=str(download_path))
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def list_all_loras(request: Request) -> LoraList:
"""
Lists all LoRAs in the lora directory.
Requires an admin key to see all LoRAs.
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
return loras
# Currently loaded loras endpoint
@router.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def active_loras() -> LoraList:
"""Returns the currently loaded loras."""
return get_active_loras()
# Load lora endpoint
@router.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
"""Loads a LoRA into the model container."""
if not data.loras:
error_message = handle_request_error(
"A list of LoRAs to load was not provided.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
load_result = await model.load_loras(
lora_dir, **data.model_dump(), skip_wait=data.skip_queue
)
return LoraLoadResponse(
success=unwrap(load_result.get("success"), []),
failure=unwrap(load_result.get("failure"), []),
)
# Unload lora endpoint
@router.post(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
"""Unloads the currently loaded loras."""
await model.unload_loras()
@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
async def list_embedding_models(request: Request) -> ModelList:
"""
Lists all embedding models in the model directory.
Requires an admin key to see all embedding models.
"""
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings_config().get("embedding_model_dir"), "models"
)
embedding_model_path = pathlib.Path(embedding_model_dir)
models = get_model_list(embedding_model_path.resolve())
else:
models = await get_current_model_list(model_type="embedding")
return models
@router.get(
"/v1/model/embedding",
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def get_embedding_model() -> ModelCard:
"""Returns the currently loaded embedding model."""
models = await get_current_model_list(model_type="embedding")
return models.data[0]
@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
async def load_embedding_model(
request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
"""Loads an embedding model."""

# Verify request parameters
if not data.name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.embeddings_config().get("embedding_model_dir"), "models")
)
embedding_model_path = embedding_model_dir / data.name
if not embedding_model_path.exists():
error_message = handle_request_error(
"Could not find the embedding model path for load. "
+ "Check model name or config.yml?",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
try:
load_task = asyncio.create_task(
model.load_embedding_model(embedding_model_path, **data.model_dump())
)
await run_with_request_disconnect(
request, load_task, "Embedding model load request cancelled by user."
)
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
response = ModelLoadResponse(
model_type="embedding_model", module=1, modules=1, status="finished"
)
return response
@router.post(
"/v1/model/embedding/unload",
dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
)
async def unload_embedding_model():
"""Unloads the current embedding model."""
await model.unload_embedding_model()
# Encode tokens endpoint
@router.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
if isinstance(data.text, str):
text = data.text
else:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
template_vars = {
"messages": data.text,
"add_generation_prompt": False,
**special_tokens_dict,
}
text, _ = model.container.prompt_template.render(template_vars)
raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
return response
# Decode tokens endpoint
@router.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
"""Decodes tokens into a string."""
message = model.container.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))
return response
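The encode/decode pair above is expected to round-trip: decoding the tokens returned by `/v1/token/encode` reproduces the original text. A toy sketch with a hypothetical whitespace tokenizer (the real container uses the loaded model's own tokenizer):

```python
def encode_tokens(text: str, vocab: dict) -> list:
    """Assign each whitespace-separated word an id, growing the vocab as needed."""
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


def decode_tokens(tokens: list, vocab: dict) -> str:
    """Invert the vocab mapping and join the words back together."""
    inverse = {token_id: word for word, token_id in vocab.items()}
    return " ".join(inverse[token] for token in tokens)


vocab = {}
tokens = encode_tokens("hello tabby api", vocab)
print(tokens)                        # three fresh ids
print(decode_tokens(tokens, vocab))  # round-trips to the original text
```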
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def key_permission(request: Request) -> AuthPermissionResponse:
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-admin-key
- X-api-key
- Authorization
"""
try:
permission = get_key_permission(request)
return AuthPermissionResponse(permission=permission)
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
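The header priority documented above can be sketched as an ordered lookup. The names and return values here are assumptions for illustration; the real `get_key_permission` also validates the key values, not merely their presence:

```python
def resolve_permission(headers: dict) -> str:
    """Check auth headers in the documented priority order."""
    if "x-admin-key" in headers:
        return "admin"
    if "x-api-key" in headers:
        return "api"
    if "authorization" in headers:
        return "api"
    raise ValueError("The provided authentication is invalid")


# An admin key outranks a plain API key when both are sent:
print(resolve_permission({"x-api-key": "a", "x-admin-key": "b"}))  # admin
```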
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def list_templates(request: Request) -> TemplateList:
"""
Get a list of all templates.
Requires an admin key to see all templates.
"""
template_strings = []
if get_key_permission(request) == "admin":
templates = get_all_templates()
template_strings = [template.stem for template in templates]
else:
if model.container and model.container.prompt_template:
template_strings.append(model.container.prompt_template.name)
return TemplateList(data=template_strings)
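`template.stem` above strips the file extension, so a directory of template files becomes a list of bare template names (the filenames below are illustrative):

```python
import pathlib

template_files = ["templates/alpaca.jinja", "templates/chatml.jinja"]
template_strings = [pathlib.Path(path).stem for path in template_files]
print(template_strings)  # ['alpaca', 'chatml']
```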
@router.post(
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template."""
if not data.name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
try:
model.container.prompt_template = PromptTemplate.from_file(data.name)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
exc_info=False,
).error.message
raise HTTPException(400, error_message) from e
@router.post(
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
"""Unloads the currently selected template"""
model.container.prompt_template = None
# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
"""
List all currently applied sampler overrides.
Requires an admin key to see all override presets.
"""
if get_key_permission(request) == "admin":
presets = sampling.get_all_presets()
else:
presets = []
return SamplerOverrideListResponse(
presets=presets, **sampling.overrides_container.model_dump()
)
@router.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
"""Switch the currently loaded override preset"""
if data.preset:
try:
sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
error_message = handle_request_error(
f"Sampler override preset with name {data.preset} does not exist. "
+ "Check the spelling?",
exc_info=False,
).error.message
raise HTTPException(400, error_message) from e
elif data.overrides:
sampling.overrides_from_dict(data.overrides)
else:
error_message = handle_request_error(
"A sampler override preset or dictionary wasn't provided.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
@router.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
"""Unloads the currently selected override preset"""
sampling.overrides_from_dict({})


@@ -17,6 +17,7 @@ class DownloadRequest(BaseModel):
include: List[str] = Field(default_factory=_generate_include_list)
exclude: List[str] = Field(default_factory=list)
chunk_limit: Optional[int] = None
timeout: Optional[int] = None
class DownloadResponse(BaseModel):


@@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
# Config arguments
draft_rope_scale: Optional[float] = Field(
default_factory=lambda: get_config_default(
"draft_rope_scale", 1.0, is_draft=True
"draft_rope_scale", 1.0, model_type="draft"
)
)
draft_rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
default_factory=lambda: get_config_default(
"draft_rope_alpha", None, is_draft=True
"draft_rope_alpha", None, model_type="draft"
),
examples=[1.0],
)
draft_cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default(
"draft_cache_mode", "FP16", is_draft=True
"draft_cache_mode", "FP16", model_type="draft"
)
)
@@ -137,6 +137,15 @@ class ModelLoadRequest(BaseModel):
skip_queue: Optional[bool] = False
class EmbeddingModelLoadRequest(BaseModel):
name: str
embeddings_device: Optional[str] = Field(
default_factory=lambda: get_config_default(
"embeddings_device", model_type="embedding"
)
)
class ModelLoadResponse(BaseModel):
"""Represents a model load response."""


@@ -0,0 +1,30 @@
import pathlib
from common import model
from endpoints.core.types.lora import LoraCard, LoraList
def get_lora_list(lora_path: pathlib.Path):
"""Get the list of Lora cards from the provided path."""
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id=path.name)
lora_list.data.append(lora_card)
return lora_list
def get_active_loras():
"""Get the list of currently active LoRA cards from the model container."""
if model.container:
active_loras = [
LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
)
for lora in model.container.get_loras()
]
else:
active_loras = []
return LoraList(data=active_loras)
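The `scaling` reported on each card above folds the adapter's rank/alpha ratio into the user-requested scaling. A worked sketch of that arithmetic, assuming those ExLlamaV2 attribute semantics:

```python
def reported_lora_scaling(lora_scaling: float, lora_r: int,
                          lora_alpha: float) -> float:
    """Scaling value reported on the LoraCard: user scaling times rank/alpha."""
    return lora_scaling * lora_r / lora_alpha


# A rank-16 adapter with alpha 32 loaded at scaling 1.0 reports 0.5:
print(reported_lora_scaling(1.0, 16, 32.0))  # 0.5
```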


@@ -2,11 +2,12 @@ import pathlib
from asyncio import CancelledError
from typing import Optional
from common import model
from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect
from endpoints.OAI.types.model import (
from common.utils import unwrap
from endpoints.core.types.model import (
ModelCard,
ModelCardParameters,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
@@ -31,6 +32,61 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
return model_card_list
async def get_current_model_list(model_type: str = "model"):
"""
Gets the current model in list format and with path only.
Unified for fetching both models and embedding models.
"""
current_models = []
model_path = None
# Make sure the model container exists
if model_type == "model" or model_type == "draft":
if model.container:
model_path = model.container.get_model_path(model_type == "draft")
elif model_type == "embedding":
if model.embeddings_container:
model_path = model.embeddings_container.model_dir
if model_path:
current_models.append(ModelCard(id=model_path.name))
return ModelList(data=current_models)
def get_current_model():
"""Gets the current model with all parameters."""
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card
async def stream_model_load(
data: ModelLoadRequest,
model_path: pathlib.Path,


@@ -1,40 +1,74 @@
import asyncio
from typing import Optional
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from common import config
from common.logger import UVICORN_LOG_CONFIG
from endpoints.OAI.router import router as OAIRouter
app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
description=(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
)
# Allow CORS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from common.networking import get_global_depends
from common.utils import unwrap
from endpoints.Kobold import router as KoboldRouter
from endpoints.OAI import router as OAIRouter
from endpoints.core.router import router as CoreRouter
def setup_app():
def setup_app(host: Optional[str] = None, port: Optional[int] = None):
"""Includes the correct routers for startup"""
app.include_router(OAIRouter)
app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
description=(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
dependencies=get_global_depends(),
)
# Allow CORS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
api_servers = unwrap(config.network_config().get("api_servers"), [])
# Map API server ids to their routers
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
# Include the configured API servers, falling back to the OAI API by default
if api_servers:
for server in api_servers:
selected_server = router_mapping.get(server.lower())
if selected_server:
app.include_router(selected_server.setup())
logger.info(f"Starting {selected_server.api_name} API")
for path, url in selected_server.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")
else:
app.include_router(OAIRouter.setup())
for path, url in OAIRouter.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")
# Include core API request paths
app.include_router(CoreRouter)
return app
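The router selection above reduces to a case-insensitive dictionary dispatch with a default fallback. A minimal sketch with placeholder router names (the FastAPI wiring and logging are omitted):

```python
def select_routers(api_servers: list, mapping: dict, default: str) -> list:
    """Resolve configured API server names; fall back to the default when empty."""
    if not api_servers:
        return [default]
    # Unknown server names are skipped silently, matching the code above
    return [mapping[name.lower()] for name in api_servers
            if name.lower() in mapping]


mapping = {"oai": "OAIRouter", "kobold": "KoboldRouter"}
print(select_routers(["OAI", "Kobold"], mapping, "OAIRouter"))  # both routers
print(select_routers([], mapping, "OAIRouter"))                 # default only
```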
def export_openapi():
"""Function to return the OpenAPI JSON from the API server"""
setup_app()
app = setup_app()
return app.openapi()
@@ -43,17 +77,21 @@ async def start_api(host: str, port: int):
# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
# logger.info(f"Completions: http://{host}:{port}/v1/completions")
# logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
# Setup app
setup_app()
app = setup_app(host, port)
# Get the current event loop
loop = asyncio.get_running_loop()
config = uvicorn.Config(
app,
host=host,
port=port,
log_config=UVICORN_LOG_CONFIG,
loop=loop,
)
server = uvicorn.Server(config)

main.py

@@ -1,10 +1,10 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import asyncio
import aiofiles
import json
import os
import pathlib
import platform
import signal
from loguru import logger
from typing import Optional
@@ -23,51 +23,8 @@ if not do_export_openapi:
from backends.exllamav2.utils import check_exllama_version
async def entrypoint(args: Optional[dict] = None):
"""Entry function for program startup"""
setup_logger()
# Set up signal aborting
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
if os.getenv("EXPORT_OPENAPI", "").lower() in ("true", "1"):
openapi_json = export_openapi()
async with aiofiles.open("openapi.json", "w") as f:
await f.write(json.dumps(openapi_json))
logger.info("Successfully wrote OpenAPI spec to openapi.json")
return
# Load from YAML config
config.from_file(pathlib.Path("config.yml"))
# Parse and override config from args
if args is None:
parser = init_argparser()
args = convert_args_to_dict(parser.parse_args(), parser)
config.from_args(args)
developer_config = config.developer_config()
# Check exllamav2 version and give a descriptive error if it's too old
# Skip if launching unsafely
if unwrap(developer_config.get("unsafe_launch"), False):
logger.warning(
"UNSAFE: Skipping ExllamaV2 version check.\n"
"If you aren't a developer, please keep this off!"
)
else:
check_exllama_version()
# Enable CUDA malloc backend
if unwrap(developer_config.get("cuda_malloc_backend"), False):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
logger.warning("Enabled the experimental CUDA malloc backend.")
async def entrypoint_async():
"""Async entry function for program startup"""
network_config = config.network_config()
@@ -97,7 +54,7 @@ async def entrypoint(args: Optional[dict] = None):
load_auth_keys(unwrap(network_config.get("disable_auth"), False))
# Override the generation log options if given
log_config = config.gen_logging_config()
log_config = config.logging_config()
if log_config:
gen_logging.update_from_dict(log_config)
@@ -128,8 +85,98 @@ async def entrypoint(args: Optional[dict] = None):
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
await model.container.load_loras(lora_dir.resolve(), **lora_config)
# If an initial embedding model name is specified, create a separate container
# and load the model
embedding_config = config.embeddings_config()
embedding_model_name = embedding_config.get("embedding_model_name")
if embedding_model_name:
embedding_model_path = pathlib.Path(
unwrap(embedding_config.get("embedding_model_dir"), "models")
)
embedding_model_path = embedding_model_path / embedding_model_name
try:
await model.load_embedding_model(embedding_model_path, **embedding_config)
except ImportError as ex:
logger.error(ex.msg)
await start_api(host, port)
def entrypoint(arguments: Optional[dict] = None):
setup_logger()
# Set up signal aborting
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Load from YAML config
config.from_file(pathlib.Path("config.yml"))
# Parse and override config from args
if arguments is None:
parser = init_argparser()
arguments = convert_args_to_dict(parser.parse_args(), parser)
config.from_args(arguments)
if do_export_openapi:
openapi_json = export_openapi()
with open("openapi.json", "w") as f:
f.write(json.dumps(openapi_json))
logger.info("Successfully wrote OpenAPI spec to openapi.json")
return
developer_config = config.developer_config()
# Check exllamav2 version and give a descriptive error if it's too old
# Skip if launching unsafely
if unwrap(developer_config.get("unsafe_launch"), False):
logger.warning(
"UNSAFE: Skipping ExllamaV2 version check.\n"
"If you aren't a developer, please keep this off!"
)
else:
check_exllama_version()
# Enable CUDA malloc backend
if unwrap(developer_config.get("cuda_malloc_backend"), False):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")
# Use Uvloop/Winloop
if unwrap(developer_config.get("uvloop"), False):
if platform.system() == "Windows":
from winloop import install
else:
from uvloop import install
# Set loop event policy
install()
logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")
# Set the process priority
if unwrap(developer_config.get("realtime_process_priority"), False):
import psutil
current_process = psutil.Process(os.getpid())
if platform.system() == "Windows":
current_process.nice(psutil.REALTIME_PRIORITY_CLASS)
else:
# psutil's nice() takes a niceness value on Linux; -20 is the highest priority
current_process.nice(-20)
logger.warning(
"EXPERIMENTAL: Process priority set to Realtime. \n"
"If you're not running on administrator/sudo, the priority is set to high."
)
# Enter into the async event loop
asyncio.run(entrypoint_async())
if __name__ == "__main__":
asyncio.run(entrypoint())
entrypoint()


@@ -16,7 +16,7 @@ version = "0.0.1"
description = "An OAI compatible exllamav2 API that's both lightweight and fast"
requires-python = ">=3.10"
dependencies = [
"fastapi >= 0.110.0",
"fastapi-slim >= 0.110.0",
"pydantic >= 2.0.0",
"PyYAML",
"rich",
@@ -28,13 +28,21 @@ dependencies = [
"tokenizers",
"lm-format-enforcer >= 0.9.6",
"aiofiles",
"aiohttp",
"huggingface_hub",
"psutil",
"httptools>=0.5.0",
# Improved asyncio loops
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
"winloop ; platform_system == 'Windows'",
# TEMP: Remove once 2.x is fixed in upstream
"numpy < 2.0.0",
# TODO: Maybe move these to a downloader feature?
"aiohttp",
"huggingface_hub",
# For python 3.12
"fastparquet @ https://github.com/theroyallab/fastparquet/releases/download/v2024.5.0/fastparquet-0.1.dev837-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"setuptools ; python_version == '3.12'"
]
[project.urls]
@@ -43,7 +51,9 @@ dependencies = [
[project.optional-dependencies]
extras = [
# Heavy dependencies that aren't for everyday use
"outlines"
"outlines",
"infinity-emb",
"sentence-transformers",
]
dev = [
"ruff == 0.3.2"
@@ -58,22 +68,22 @@ cu121 = [
"torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
]
cu118 = [
# Torch
@@ -85,17 +95,17 @@ cu118 = [
"torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
]
amd = [
# Torch triton for ROCm
@@ -109,9 +119,9 @@ amd = [
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
]
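These pins select one wheel per interpreter via PEP 508 environment markers; pip evaluates each marker against the running Python and keeps only the matching entry. A minimal sketch of that selection logic (the helper and the trimmed wheel names below are illustrative, not part of the project's metadata):

```python
import platform
import sys

def marker_matches(python_version: str, system: str, machine: str) -> bool:
    """Return True when the current interpreter satisfies a pin like
    "platform_system == 'Linux' and platform_machine == 'x86_64'
    and python_version == '3.11'"."""
    current_py = f"{sys.version_info.major}.{sys.version_info.minor}"
    return (
        current_py == python_version
        and platform.system() == system
        and platform.machine() == machine
    )

# pip keeps only the wheel whose marker matches this interpreter
pins = {
    "3.10": "exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl",
    "3.11": "exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl",
    "3.12": "exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl",
}
selected = [
    wheel for ver, wheel in pins.items()
    if marker_matches(ver, "Linux", "x86_64")
]
```

At most one wheel survives for a given interpreter, which is why each extras group can list every cpXY build without pip ever trying to install two of them.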
# MARK: Ruff options


@@ -19,3 +19,5 @@ if exist "%CONDA_PREFIX%" (
:: Call the python script with batch args
call python start.py %*
pause

start.py

@@ -1,17 +1,21 @@
"""Utility to automatically upgrade and start the API"""
import asyncio
import argparse
import json
import os
import pathlib
import platform
import subprocess
import sys
from shutil import copyfile
import traceback
from common.args import convert_args_to_dict, init_argparser
start_options = {}
def get_user_choice(question: str, options_dict: dict):
"""
Gets user input in a commandline script.
@@ -40,36 +44,24 @@ def get_install_features(lib_name: str = None):
"""Fetches the appropriate requirements file depending on the GPU"""
install_features = None
possible_features = ["cu121", "cu118", "amd"]
saved_lib_path = pathlib.Path("gpu_lib.txt")
if lib_name:
print("Overriding GPU lib name from args.")
else:
# Try getting the GPU lib from file
if saved_lib_path.exists():
with open(saved_lib_path.resolve(), "r") as f:
lib_name = f.readline().strip()
else:
# Ask the user for the GPU lib
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
"C": {"pretty": "AMD", "internal": "amd"},
}
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)
if not lib_name:
# Ask the user for the GPU lib
gpu_lib_choices = {
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
"C": {"pretty": "AMD", "internal": "amd"},
}
user_input = get_user_choice(
"Select your GPU. If you don't know, select Cuda 12.x (A)",
gpu_lib_choices,
)
lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
# Write to a file for subsequent runs
with open(saved_lib_path.resolve(), "w") as f:
f.write(lib_name)
print(
"Saving your choice to gpu_lib.txt. "
"Delete this file and restart if you want to change your selection."
)
# Write to start options
start_options["gpu_lib"] = lib_name
print("Saving your choice to start options.")
# Assume default if the file is invalid
if lib_name and lib_name in possible_features:
@@ -79,7 +71,7 @@ def get_install_features(lib_name: str = None):
print(
f"WARN: GPU library {lib_name} not found. "
"Skipping GPU-specific dependencies.\n"
"WARN: Please delete gpu_lib.txt and restart "
"WARN: Please remove the `gpu_lib` key from start_options.json and restart "
"if you want to change your selection."
)
return
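After the prompt-or-override handling, get_install_features reduces to validating the library name against the known extras groups and falling back to base dependencies otherwise. A hedged sketch of that final check (pick_install_features is an illustrative name, not the repo's API):

```python
def pick_install_features(lib_name, possible_features=("cu121", "cu118", "amd")):
    """Return the extras group to install, or None to skip GPU-specific wheels."""
    if lib_name and lib_name in possible_features:
        return lib_name
    # Unknown or missing library name: install base dependencies only
    return None
```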
@@ -105,10 +97,22 @@ def add_start_args(parser: argparse.ArgumentParser):
"""Add start script args to the provided parser"""
start_group = parser.add_argument_group("start")
start_group.add_argument(
"-iu",
"--ignore-upgrade",
"-ur",
"--update-repository",
action="store_true",
help="Ignore requirements upgrade",
help="Update local git repository to latest",
)
start_group.add_argument(
"-ud",
"--update-deps",
action="store_true",
help="Update all pip dependencies",
)
start_group.add_argument(
"-fr",
"--force-reinstall",
action="store_true",
help="Forces a reinstall of dependencies. Only works with --update-deps",
)
start_group.add_argument(
"-nw",
@@ -123,6 +127,26 @@ def add_start_args(parser: argparse.ArgumentParser):
)
def migrate_gpu_lib():
gpu_lib_path = pathlib.Path("gpu_lib.txt")
if not gpu_lib_path.exists():
return
print("Migrating gpu_lib.txt to the new start_options.json")
with open("gpu_lib.txt", "r") as gpu_lib_file:
start_options["gpu_lib"] = gpu_lib_file.readline().strip()
start_options["first_run_done"] = True
# Remove the old file
gpu_lib_path.unlink()
print(
"Successfully migrated gpu lib options to start_options. "
"The old file has been deleted."
)
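The migration amounts to: read the single line from gpu_lib.txt into the options dict, mark the first run as done, and delete the old file. A self-contained rehearsal in a temporary directory (the helper below mirrors migrate_gpu_lib but takes its paths as arguments; it is an illustrative sketch, not the repo's code):

```python
import pathlib
import tempfile

def migrate_gpu_lib(workdir: pathlib.Path, start_options: dict) -> None:
    """Move the saved GPU lib choice from gpu_lib.txt into start_options."""
    gpu_lib_path = workdir / "gpu_lib.txt"
    if not gpu_lib_path.exists():
        return
    start_options["gpu_lib"] = gpu_lib_path.read_text().splitlines()[0].strip()
    start_options["first_run_done"] = True
    gpu_lib_path.unlink()  # the old file is no longer needed

with tempfile.TemporaryDirectory() as tmp:
    workdir = pathlib.Path(tmp)
    (workdir / "gpu_lib.txt").write_text("cu121\n")
    opts = {}
    migrate_gpu_lib(workdir, opts)
```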
if __name__ == "__main__":
subprocess.run(["pip", "-V"])
@@ -130,6 +154,35 @@ if __name__ == "__main__":
parser = init_argparser()
add_start_args(parser)
args = parser.parse_args()
script_ext = "bat" if platform.system() == "Windows" else "sh"
start_options_path = pathlib.Path("start_options.json")
if start_options_path.exists():
with open(start_options_path) as start_options_file:
start_options = json.load(start_options_file)
print("Loaded your saved preferences from `start_options.json`")
if start_options.get("first_run_done"):
first_run = False
else:
print(
"It looks like you're running TabbyAPI for the first time. "
"Getting things ready..."
)
# Migrate from old setting storage
migrate_gpu_lib()
# Set variables that rely on start options
first_run = not start_options.get("first_run_done")
if args.gpu_lib:
print("Overriding GPU lib name from args.")
gpu_lib = args.gpu_lib
elif "gpu_lib" in start_options:
gpu_lib = start_options.get("gpu_lib")
else:
gpu_lib = None
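The precedence here is: the command-line flag wins, then the saved start_options entry, and finally None, which defers to the interactive prompt in get_install_features. As a sketch (resolve_gpu_lib is a hypothetical helper, not a function in start.py):

```python
def resolve_gpu_lib(arg_value, start_options):
    """CLI argument overrides saved preferences; None defers to the prompt."""
    if arg_value:
        return arg_value
    return start_options.get("gpu_lib")
```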
# Create a config if it doesn't exist
# This isn't required to run TabbyAPI, but it helps first-time users
@@ -145,18 +198,65 @@ if __name__ == "__main__":
f"Created one at {str(config_path.resolve())}"
)
if args.ignore_upgrade:
print("Ignoring pip dependency upgrade due to user request.")
else:
install_features = None if args.nowheel else get_install_features(args.gpu_lib)
features = f"[{install_features}]" if install_features else ""
if args.update_repository:
print("Pulling latest changes from GitHub.")
pull_command = "git pull"
subprocess.run(pull_command.split(" "))
if first_run or args.update_deps:
install_command = ["pip", "install", "-U"]
# Force a reinstall of the updated dependency if needed
if args.force_reinstall:
install_command.append("--force-reinstall")
install_features = None if args.nowheel else get_install_features(gpu_lib)
features = f".[{install_features}]" if install_features else "."
install_command.append(features)
# pip install .[features]
install_command = f"pip install -U .{features}"
print(f"Running install command: {install_command}")
subprocess.run(install_command.split(" "))
print(f"Running install command: {' '.join(install_command)}")
subprocess.run(install_command)
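Building the command as an argument list, rather than one string split on spaces, keeps each flag intact when it is handed to subprocess.run. Roughly (build_install_command is a hypothetical condensation of the logic above, not the repo's API):

```python
def build_install_command(force_reinstall, install_features):
    """Assemble the pip invocation used to (re)install the project."""
    command = ["pip", "install", "-U"]
    if force_reinstall:
        command.append("--force-reinstall")
    # "pip install .[features]" — "." installs the local project itself
    command.append(f".[{install_features}]" if install_features else ".")
    return command
```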
print()
if args.update_deps:
print(
f"Dependencies updated. Please run TabbyAPI with `start.{script_ext}`. "
"Exiting."
)
sys.exit(0)
else:
print(
f"Dependencies installed. Update them with `update_deps.{script_ext}` "
"inside the `update_scripts` folder."
)
if first_run:
start_options["first_run_done"] = True
# Save start options
with open("start_options.json", "w") as start_file:
start_file.write(json.dumps(start_options))
print(
"Successfully wrote your start script options to `start_options.json`. \n"
"If something goes wrong, editing or deleting the file "
"will reinstall TabbyAPI as a first-time user."
)
# Import entrypoint after installing all requirements
from main import entrypoint
try:
from main import entrypoint
asyncio.run(entrypoint(convert_args_to_dict(args, parser)))
converted_args = convert_args_to_dict(args, parser)
print("Starting TabbyAPI...")
entrypoint(converted_args)
except (ModuleNotFoundError, ImportError):
traceback.print_exc()
print(
"\n"
"This error was raised because a package was not found.\n"
"Update your dependencies by running update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'}\n\n"
)


@@ -0,0 +1,24 @@
@echo off
:: Activates an existing venv (or conda environment) and runs the start script with dependency updates
:: This is intended for users who want their requirements upgraded and reinstalled
:: cd to the parent directory
cd "%~dp0.."
:: Don't create a venv if a conda environment is active
if exist "%CONDA_PREFIX%" (
echo It looks like you're in a conda environment. Skipping venv check.
) else (
if not exist "venv\" (
echo Venv doesn't exist! Please run start.bat instead.
exit /b 1
)
call .\venv\Scripts\activate.bat
)
:: Call the python script with batch args
call python start.py --update-deps %*
pause

update_scripts/update_deps.sh Executable file

@@ -0,0 +1,19 @@
#!/bin/bash
cd "$(dirname "$0")/.." || exit
if [ -n "$CONDA_PREFIX" ]; then
echo "It looks like you're in a conda environment. Skipping venv check."
else
if [ ! -d "venv" ]; then
echo "Venv doesn't exist! Please run start.sh instead."
exit 1
fi
echo "Activating venv"
# shellcheck source=/dev/null
source venv/bin/activate
fi
python3 start.py --update-deps "$@"


@@ -0,0 +1,24 @@
@echo off
:: Activates an existing venv (or conda environment) and runs the start script with repository and dependency updates
:: This is intended for users who want everything pulled, upgraded, and installed in one step
:: cd to the parent directory
cd "%~dp0.."
:: Don't create a venv if a conda environment is active
if exist "%CONDA_PREFIX%" (
echo It looks like you're in a conda environment. Skipping venv check.
) else (
if not exist "venv\" (
echo Venv doesn't exist! Please run start.bat instead.
exit /b 1
)
call .\venv\Scripts\activate.bat
)
:: Call the python script with batch args
call python start.py --update-deps --update-repository %*
pause


@@ -0,0 +1,19 @@
#!/bin/bash
cd "$(dirname "$0")/.." || exit
if [ -n "$CONDA_PREFIX" ]; then
echo "It looks like you're in a conda environment. Skipping venv check."
else
if [ ! -d "venv" ]; then
echo "Venv doesn't exist! Please run start.sh instead."
exit 1
fi
echo "Activating venv"
# shellcheck source=/dev/null
source venv/bin/activate
fi
python3 start.py --update-deps --update-repository "$@"