mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Merge branch 'main' into main
This commit is contained in:
35
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -1,35 +0,0 @@
---
name: Bug report
about: Report code related issues
title: "[BUG]"
labels: bug
assignees: ''

---

**Disclaimer:** Github Issues are **only** for code related bugs. If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj)

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Logs**
If applicable, add logs and tracebacks to help explain your problem.

**System info** (Bugs without this information will go lower on our priority list!)
- OS: [ex. Windows]
- Python version: [ex. 3.11]
- CUDA/ROCm version: [ex. 12.x]
- Python version: [ex. 3.11]

**Additional context**
Add any other context about the problem here.
97
.github/ISSUE_TEMPLATE/bug_report.yaml
vendored
Normal file
@@ -0,0 +1,97 @@
name: Bug report
description: Report code related issues
title: "[BUG]"
labels: bug
body:

  - type: markdown
    attributes:
      value: |
        ### Disclaimer:
        Github Issues are **only** for code related bugs.
        If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj)

  - type: dropdown
    attributes:
      label: OS
      options:
        - Windows
        - Linux
    validations:
      required: true

  - type: dropdown
    attributes:
      label: GPU Library
      description: Ex. CUDA, ROCm
      options:
        - CUDA 12.x
        - CUDA 11.8
        - AMD ROCm
    validations:
      required: true

  - type: dropdown
    attributes:
      label: Python version
      options:
        - '3.12'
        - '3.11'
        - '3.10'
    validations:
      required: true

  - type: textarea
    attributes:
      label: Describe the bug
      description: A clear and concise description of what the bug is.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Reproduction steps
      description: Walk us through how the bug occurred and how to make it happen.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Expected behavior
      description: What was expected to happen?
    validations:
      required: true

  - type: textarea
    attributes:
      label: Logs
      description: If applicable, add logs and tracebacks to help explain your problem.
    validations:
      required: false

  - type: textarea
    attributes:
      label: Additional context
      description: Add any other context about the problem here.
    validations:
      required: false

  - type: checkboxes
    attributes:
      label: Acknowledgements
      description: Before submitting this issue, please make sure you have completed the following checklist.
      options:
        - label: I have looked for similar issues before submitting this one.
          required: true
        - label: I have read the disclaimer, and this issue is related to a code bug. If I have a question, I will use the Discord server.
          required: true
        - label: I understand that the developers have lives and my issue will be answered when possible.
          required: true
        - label: I understand the developers of this program are human, and I will ask my questions politely.
          required: true

  - type: markdown
    attributes:
      value: |
        ## Thanks!
        Well-formatted issues improve TabbyAPI and make the development process smoother.
26
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -1,26 +0,0 @@
---
name: Feature request
about: Suggest a new idea
title: "[REQUEST]"
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Why should this feature be added?**
An explanation of why the feature should be added. Please be as specific as possible to help us understand the reasoning.

**Examples**
Examples of the feature in action and its significance compared to not having the feature.

**Additional context**
Add any other context or screenshots about the feature request here.
69
.github/ISSUE_TEMPLATE/feature_request.yaml
vendored
Normal file
@@ -0,0 +1,69 @@
name: Feature request
description: Suggest a new idea
title: "[REQUEST]"
body:

  - type: textarea
    attributes:
      label: Problem
      description: Is the feature request related to a problem? If so, please describe.
      placeholder: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
    validations:
      required: false

  - type: textarea
    attributes:
      label: Solution
      description: Describe the solution you'd like.
      placeholder: A clear and concise description of what you want to happen.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Alternatives
      description: What alternative options did you consider?
    validations:
      required: false

  - type: textarea
    attributes:
      label: Explanation
      description: Why should this feature be added?
    validations:
      required: true

  - type: textarea
    attributes:
      label: Examples
      description: |
        Examples of the feature in action and its significance.

        Not required, but will make your request easier to understand.
    validations:
      required: false

  - type: textarea
    attributes:
      label: Additional context
      description: Anything else to add?
    validations:
      required: false

  - type: checkboxes
    attributes:
      label: Acknowledgements
      description: Before submitting this issue, please make sure you have completed the following checklist.
      options:
        - label: I have looked for similar requests before submitting this one.
          required: true
        - label: I understand that the developers have lives and my issue will be answered when possible.
          required: true
        - label: I understand the developers of this program are human, and I will make my requests politely.
          required: true

  - type: markdown
    attributes:
      value: |
        ## Thanks!
        Well-formatted issues improve TabbyAPI and make the development process smoother.
11
.github/workflows/pages.yml
vendored
@@ -47,12 +47,17 @@ jobs:
         run: |
           npm install @redocly/cli -g
       - name: Export OpenAPI docs
-        run: EXPORT_OPENAPI=1 python main.py
+        run: |
+          EXPORT_OPENAPI=1 python main.py
+          mv openapi.json openapi-oai.json
+          EXPORT_OPENAPI=1 python main.py --api-servers kobold
+          mv openapi.json openapi-kobold.json
       - name: Build and store Redocly site
         run: |
-          redocly build-docs openapi.json
           mkdir static
-          mv redoc-static.html static/index.html
+          mkdir static/kobold
+          redocly build-docs openapi-oai.json -o static/index.html
+          redocly build-docs openapi-kobold.json -o static/kobold/index.html
       - name: Setup Pages
         uses: actions/configure-pages@v5
       - name: Upload artifact
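The workflow change above exports the spec twice and renames between runs because the export step always writes `openapi.json`. A minimal Python sketch of that export-then-rename flow (`export_spec` is an illustrative stand-in for `EXPORT_OPENAPI=1 python main.py`, not TabbyAPI code):

```python
import json
import os
import tempfile

def export_spec(path, server):
    # Stand-in for the exporter: always writes the same fixed filename
    with open(path, "w") as f:
        json.dump({"info": {"title": f"TabbyAPI ({server})"}}, f)

workdir = tempfile.mkdtemp()
spec = os.path.join(workdir, "openapi.json")

# Each export overwrites openapi.json, so rename before the next run
export_spec(spec, "OAI")
os.rename(spec, os.path.join(workdir, "openapi-oai.json"))

export_spec(spec, "Kobold")
os.rename(spec, os.path.join(workdir, "openapi-kobold.json"))

print(sorted(os.listdir(workdir)))  # ['openapi-kobold.json', 'openapi-oai.json']
```

Without the rename, the second export would clobber the first spec and only the Kobold docs would be built.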
6
.gitignore
vendored
@@ -200,5 +200,11 @@ sampler_overrides/*
 # Gpu lib preferences file
 gpu_lib.txt

 # Start options file
 start_options.json
+
+# OpenAPI JSON
+openapi.json
+
+# Infinity-emb cache
+.infinity_cache/
50
README.md
@@ -32,20 +32,41 @@

 A FastAPI based application that allows for generating text using an LLM (large language model) using the [Exllamav2 backend](https://github.com/turboderp/exllamav2)

+TabbyAPI is also the official API backend server for ExllamaV2.
+
 ## Disclaimer

-This project is marked rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies if needed.
+This project is marked as rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies if needed.

-TabbyAPI is a hobby project solely for a small amount of users. It is not meant to run on production servers. For that, please look at other backends that support those workloads.
+TabbyAPI is a hobby project made for a small amount of users. It is not meant to run on production servers. For that, please look at other solutions that support those workloads.

 ## Getting Started

 > [!IMPORTANT]
 >
-> This README is not for getting started. Please read the Wiki.
+> This README does not have instructions for setting up. Please read the Wiki.

 Read the [Wiki](https://github.com/theroyallab/tabbyAPI/wiki/1.-Getting-Started) for more information. It contains user-facing documentation for installation, configuration, sampling, API usage, and so much more.

+## Features
+
+- OpenAI compatible API
+- Loading/unloading models
+- HuggingFace model downloading
+- Embedding model support
+- JSON schema + Regex + EBNF support
+- AI Horde support
+- Speculative decoding via draft models
+- Multi-lora with independent scaling (ex. a weight of 0.9)
+- Inbuilt proxy to override client request parameters/samplers
+- Flexible Jinja2 template engine for chat completions that conforms to HuggingFace
+- Concurrent inference with asyncio
+- Utilizes modern python paradigms
+- Continuous batching engine using paged attention
+- Fast classifier-free guidance
+
+And much more. If something is missing here, PR it in!
+
 ## Supported Model Types

 TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, loading, etc. Therefore, the following types of models are supported:
@@ -58,18 +79,6 @@ TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, load

 In addition, TabbyAPI supports parallel batching using paged attention for Nvidia Ampere GPUs and higher.

-#### Alternative Loaders/Backends
-
-If you want to use a different model type or quantization method than the ones listed above, here are some alternative backends with their own APIs:
-
-- GGUF + GGML - [KoboldCPP](https://github.com/lostruins/KoboldCPP)
-
-- Production ready + Many other quants + batching - [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine)
-
-- Production ready + batching - [VLLM](https://github.com/vllm-project/vllm)
-
-- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui)
-
 ## Contributing

 Use the template when creating issues or pull requests, otherwise the developers may not look at your post.
@@ -84,6 +93,17 @@ If you have a Pull Request

 - Describe the pull request in detail, what, and why you are changing something

+## Acknowledgements
+
+TabbyAPI would not exist without the work of other contributors and FOSS projects:
+
+- [ExllamaV2](https://github.com/turboderp/exllamav2)
+- [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine)
+- [infinity-emb](https://github.com/michaelfeil/infinity)
+- [FastAPI](https://github.com/fastapi/fastapi)
+- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui)
+- [SillyTavern](https://github.com/SillyTavern/SillyTavern)
+
 ## Developers and Permissions

 Creators/Developers:
@@ -47,7 +47,7 @@ from common.templating import (
     TemplateLoadError,
     find_template_from_model,
 )
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import GenerationConfig, HuggingFaceConfig
 from common.utils import coalesce, unwrap


@@ -70,8 +70,9 @@ class ExllamaV2Container:
     cache_size: int = None
     cache_mode: str = "FP16"
     draft_cache_mode: str = "FP16"
-    max_batch_size: int = 20
+    max_batch_size: Optional[int] = None
     generation_config: Optional[GenerationConfig] = None
+    hf_config: Optional[HuggingFaceConfig] = None

     # GPU split vars
     gpu_split: Optional[list] = None
@@ -186,6 +187,9 @@ class ExllamaV2Container:
         except AttributeError:
             pass

+        # Create the hf_config
+        self.hf_config = HuggingFaceConfig.from_file(model_directory)
+
         # Then override the base_seq_len if present
         override_base_seq_len = kwargs.get("override_base_seq_len")
         if override_base_seq_len:
@@ -213,6 +217,9 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)

+        # Set max batch size to the config override
+        self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
+
         # Check whether the user's configuration supports flash/paged attention
         # Also check if exl2 has disabled flash attention
         if (
@@ -268,15 +275,8 @@ class ExllamaV2Container:
         else:
             self.cache_size = self.config.max_seq_len

-        # Try to set prompt template
-        self.prompt_template = self.find_prompt_template(
-            kwargs.get("prompt_template"), model_directory
-        )
-
         # Load generation config overrides
-        generation_config_path = (
-            pathlib.Path(self.config.model_dir) / "generation_config.json"
-        )
+        generation_config_path = model_directory / "generation_config.json"
         if generation_config_path.exists():
             try:
                 self.generation_config = GenerationConfig.from_file(
@@ -288,6 +288,11 @@ class ExllamaV2Container:
                     "Skipping generation config load because of an unexpected error."
                 )

+        # Try to set prompt template
+        self.prompt_template = self.find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )
+
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
@@ -405,6 +410,9 @@ class ExllamaV2Container:
     def get_model_path(self, is_draft: bool = False):
         """Get the path for this model."""

+        if is_draft and not self.draft_config:
+            return None
+
         model_path = pathlib.Path(
             self.draft_config.model_dir if is_draft else self.config.model_dir
         )
@@ -446,6 +454,11 @@ class ExllamaV2Container:

         # Immediately abort all jobs if asked
         if skip_wait:
+            logger.warning(
+                "Immediately terminating all jobs. "
+                "Clients will have their requests cancelled.\n"
+            )
+
             # Requires a copy to avoid errors during iteration
             jobs_copy = self.generator.jobs.copy()
             for job in jobs_copy.values():
@@ -485,15 +498,7 @@ class ExllamaV2Container:
                 yield value

         # Create async generator
-        self.generator = ExLlamaV2DynamicGeneratorAsync(
-            model=self.model,
-            cache=self.cache,
-            draft_model=self.draft_model,
-            draft_cache=self.draft_cache,
-            tokenizer=self.tokenizer,
-            max_batch_size=self.max_batch_size,
-            paged=self.paged,
-        )
+        await self.create_generator()

         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
@@ -642,6 +647,34 @@ class ExllamaV2Container:
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

+    async def create_generator(self):
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.model_loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = ExLlamaV2DynamicGeneratorAsync(
+                model=self.model,
+                cache=self.cache,
+                draft_model=self.draft_model,
+                draft_cache=self.draft_cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+                paged=self.paged,
+            )
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.model_loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     def get_loras(self):
         """Convenience function to get all loras."""

@@ -701,11 +734,15 @@ class ExllamaV2Container:
         Free all VRAM resources used by this model
         """

-        try:
-            await self.load_lock.acquire()
-
-            # Wait for other jobs to finish
-            await self.wait_for_jobs(kwargs.get("skip_wait"))
+        # Shutdown immediately unloads and bypasses all locks
+        do_shutdown = kwargs.get("shutdown")
+
+        try:
+            if not do_shutdown:
+                await self.load_lock.acquire()
+
+                # Wait for other jobs to finish
+                await self.wait_for_jobs(kwargs.get("skip_wait"))

             # Delete references held in the grammar module
             clear_grammar_func_cache()
@@ -745,10 +782,11 @@ class ExllamaV2Container:

             logger.info("Loras unloaded." if loras_only else "Model unloaded.")
         finally:
-            self.load_lock.release()
+            if not do_shutdown:
+                self.load_lock.release()

-            async with self.load_condition:
-                self.load_condition.notify_all()
+                async with self.load_condition:
+                    self.load_condition.notify_all()

     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
@@ -800,10 +838,14 @@ class ExllamaV2Container:

         return dict(zip_longest(top_tokens, cleaned_values))

-    async def generate(self, prompt: str, **kwargs):
+    async def generate(
+        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+    ):
         """Generate a response to a prompt"""
         generations = []
-        async for generation in self.generate_gen(prompt, **kwargs):
+        async for generation in self.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        ):
             generations.append(generation)

         joined_generation = {
@@ -853,7 +895,11 @@ class ExllamaV2Container:
         return kwargs

     async def generate_gen(
-        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: Optional[asyncio.Event] = None,
+        **kwargs,
     ):
         """
         Create generator function for prompt completion.
@@ -972,8 +1018,37 @@ class ExllamaV2Container:
             kwargs.get("repetition_decay"), fallback_decay, 0
         )

-        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        # Initialize grammar handler
+        grammar_handler = ExLlamaV2Grammar()
+
+        # Add JSON schema filter if it exists
+        json_schema = unwrap(kwargs.get("json_schema"))
+        if json_schema:
+            grammar_handler.add_json_schema_filter(
+                json_schema, self.model, self.tokenizer
+            )
+
+        # Add regex filter if it exists
+        regex_pattern = unwrap(kwargs.get("regex_pattern"))
+        if regex_pattern:
+            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+
+        # Add EBNF filter if it exists
+        grammar_string = unwrap(kwargs.get("grammar_string"))
+        if grammar_string:
+            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+
+        # Set banned strings
+        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
+        if banned_strings and len(grammar_handler.filters) > 0:
+            logger.warning(
+                "Disabling banned_strings because "
+                "they cannot be used with grammar filters."
+            )
+
+            banned_strings = []
+
+        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
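The hunk above moves grammar-filter setup earlier so that banned strings can be checked against the active filters: when any grammar filter exists, banned strings are dropped. A hypothetical helper mirroring just that precedence rule (not actual TabbyAPI code):

```python
def resolve_banned_strings(banned_strings, filters):
    """Return the banned strings to use, dropping them when grammar filters are active."""
    if banned_strings and len(filters) > 0:
        # Grammar filters win; banned strings cannot be combined with them
        return []
    return banned_strings

print(resolve_banned_strings(["spoiler"], ["json_schema_filter"]))  # []
print(resolve_banned_strings(["spoiler"], []))  # ['spoiler']
```

This is why the diff reorders the code: the check only works once `grammar_handler.filters` is populated.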
@@ -1021,26 +1096,6 @@ class ExllamaV2Container:
                     "in the model's vocab. Skipping."
                 )

-        # Initialize grammar handler
-        grammar_handler = ExLlamaV2Grammar()
-
-        # Add JSON schema filter if it exists
-        json_schema = unwrap(kwargs.get("json_schema"))
-        if json_schema:
-            grammar_handler.add_json_schema_filter(
-                json_schema, self.model, self.tokenizer
-            )
-
-        # Add regex filter if it exists
-        regex_pattern = unwrap(kwargs.get("regex_pattern"))
-        if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
-
-        # Add EBNF filter if it exists
-        grammar_string = unwrap(kwargs.get("grammar_string"))
-        if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
-
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
             self.generation_config.eos_tokens()
@@ -1085,34 +1140,11 @@ class ExllamaV2Container:
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)

-        # Log generation options to console
-        # Some options are too large, so log the args instead
-        log_generation_params(
-            max_tokens=max_tokens,
-            min_tokens=min_tokens,
-            stream=kwargs.get("stream"),
-            **gen_settings_log_dict,
-            token_healing=token_healing,
-            auto_scale_penalty_range=auto_scale_penalty_range,
-            generate_window=generate_window,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=eos_tokens,
-            add_bos_token=add_bos_token,
-            ban_eos_token=ban_eos_token,
-            skip_special_tokens=not decode_special_tokens,
-            speculative_ngram=self.generator.speculative_ngram,
-            logprobs=request_logprobs,
-            stop_conditions=stop_conditions,
-            banned_tokens=banned_tokens,
-            banned_strings=banned_strings,
-            logit_bias=logit_bias,
-            filters=grammar_handler.filters,
-        )
-
         # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        log_prompt(prompt, request_id, negative_prompt)

         # Create and add a new job
+        # Don't use the request ID here as there can be multiple jobs per request
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
@@ -1138,6 +1170,7 @@ class ExllamaV2Container:
             max_seq_len = self.config.max_seq_len
             generated_tokens = 0
             full_response = ""
+            metrics_result = {}

             # Get the generation status once it's ready
             try:
@@ -1191,23 +1224,15 @@ class ExllamaV2Container:

                     # Second yield if eos is true
                     if result.get("eos"):
-                        log_response(full_response)
+                        log_response(request_id, full_response)

                         eos_reason = result.get("eos_reason")
                         finish_reason = (
                             "length" if eos_reason == "max_new_tokens" else "stop"
                         )

-                        log_metrics(
-                            result.get("time_enqueued"),
-                            result.get("prompt_tokens"),
-                            result.get("cached_tokens"),
-                            result.get("time_prefill"),
-                            result.get("new_tokens"),
-                            result.get("time_generate"),
-                            context_len,
-                            max_seq_len,
-                        )
+                        # Save the final result for metrics logging
+                        metrics_result = result

                         # Remove the token text
                         generation = {
@@ -1220,3 +1245,53 @@ class ExllamaV2Container:
                         break
         except asyncio.CancelledError:
             await job.cancel()
+        except Exception as ex:
+            # Create a new generator since the current state is broken
+            # No need to wait for this to finish
+            logger.error(
+                "FATAL ERROR with generation. "
+                "Attempting to recreate the generator. "
+                "If this fails, please restart the server.\n"
+            )
+            asyncio.ensure_future(self.create_generator())
+
+            raise ex
+        finally:
+            # Log generation options to console
+            # Some options are too large, so log the args instead
+            log_generation_params(
+                request_id=request_id,
+                max_tokens=max_tokens,
+                min_tokens=min_tokens,
+                stream=kwargs.get("stream"),
+                **gen_settings_log_dict,
+                token_healing=token_healing,
+                auto_scale_penalty_range=auto_scale_penalty_range,
+                generate_window=generate_window,
+                bos_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=eos_tokens,
+                add_bos_token=add_bos_token,
+                ban_eos_token=ban_eos_token,
+                skip_special_tokens=not decode_special_tokens,
+                speculative_ngram=self.generator.speculative_ngram,
+                logprobs=request_logprobs,
+                stop_conditions=stop_conditions,
+                banned_tokens=banned_tokens,
+                banned_strings=banned_strings,
+                logit_bias=logit_bias,
+                filters=grammar_handler.filters,
+            )
+
+            # Log the metrics if present
+            if metrics_result:
+                log_metrics(
+                    request_id,
+                    metrics_result.get("time_enqueued"),
+                    metrics_result.get("prompt_tokens"),
+                    metrics_result.get("cached_tokens"),
+                    metrics_result.get("time_prefill"),
+                    metrics_result.get("new_tokens"),
+                    metrics_result.get("time_generate"),
+                    context_len,
+                    max_seq_len,
+                )
@@ -1,20 +1,22 @@
 import platform
-import torch
 from packaging import version
 from importlib.metadata import PackageNotFoundError, version as package_version
 from loguru import logger
+import torch


 def check_exllama_version():
     """Verifies the exllama version"""

-    required_version = version.parse("0.1.6")
+    required_version = version.parse("0.1.7")
     current_version = version.parse(package_version("exllamav2").split("+")[0])

     unsupported_message = (
         f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
         f"or greater. Your current version is {current_version}.\n"
-        "Please upgrade your environment by running a start script "
-        "(start.bat or start.sh)\n\n"
+        "Please update your environment by running an update script "
+        "(update_scripts/"
+        f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
         "Or you can manually run a requirements update "
         "using the following command:\n\n"
         "For CUDA 12.1:\n"
@@ -71,8 +73,9 @@ def supports_paged_attn():
         "Switching to compatibility mode. \n"
         "This disables parallel batching "
         "and features that rely on it (ex. CFG). \n"
-        "Please upgrade your environment by running a start script "
-        "(start.bat or start.sh)\n\n"
+        "Please upgrade your environment by running an update script "
+        "(update_scripts/"
+        f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
         "Or you can manually run a requirements update "
         "using the following command:\n\n"
         "For CUDA 12.1:\n"
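The version check above strips any local build suffix (e.g. `+cu121`) before comparing against the required minimum. A minimal self-contained sketch of that comparison, using plain tuples instead of `packaging.version` (hypothetical helper names, not TabbyAPI code):

```python
def parse_version(v: str):
    # Strip a local suffix such as "+cu121", then compare numerically
    return tuple(int(part) for part in v.split("+")[0].split("."))

def meets_requirement(current: str, required: str) -> bool:
    return parse_version(current) >= parse_version(required)

print(meets_requirement("0.1.7+cu121", "0.1.7"))  # True
print(meets_requirement("0.1.6", "0.1.7"))        # False
```

Tuple comparison works because Python compares element-wise, which is why the suffix must be stripped first: `int("7+cu121")` would raise.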
66
backends/infinity/model.py
Normal file
@@ -0,0 +1,66 @@
import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path
    model_is_loading: bool = False
    model_loaded: bool = False

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        self.model_is_loading = True

        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.model_loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}
122
common/args.py
@@ -17,13 +17,16 @@ def init_argparser():
    """Creates an argument parser that any function can use"""

    parser = argparse.ArgumentParser(
        epilog="These args are only for a subset of the config. "
        + "Please edit config.yml for all options!"
        epilog="NOTE: These args serve to override parts of the config. "
        + "It's highly recommended to edit config.yml for all options and "
        + "better descriptions!"
    )
    add_network_args(parser)
    add_model_args(parser)
    add_embeddings_args(parser)
    add_logging_args(parser)
    add_developer_args(parser)
    add_sampling_args(parser)
    add_config_args(parser)

    return parser
@@ -64,6 +67,17 @@ def add_network_args(parser: argparse.ArgumentParser):
        type=str_to_bool,
        help="Disable HTTP token authentication with requests",
    )
    network_group.add_argument(
        "--send-tracebacks",
        type=str_to_bool,
        help="Decide whether to send error tracebacks over the API",
    )
    network_group.add_argument(
        "--api-servers",
        type=str,
        nargs="+",
        help="API servers to enable. Options: (OAI, Kobold)",
    )


def add_model_args(parser: argparse.ArgumentParser):
@@ -74,6 +88,17 @@ def add_model_args(parser: argparse.ArgumentParser):
        "--model-dir", type=str, help="Overrides the directory to look for models"
    )
    model_group.add_argument("--model-name", type=str, help="An initial model to load")
    model_group.add_argument(
        "--use-dummy-models",
        type=str_to_bool,
        help="Add dummy OAI model names for API queries",
    )
    model_group.add_argument(
        "--use-as-default",
        type=str,
        nargs="+",
        help="Names of args to use as a default fallback for API load requests",
    )
    model_group.add_argument(
        "--max-seq-len", type=int, help="Override the maximum model sequence length"
    )
@@ -82,25 +107,17 @@ def add_model_args(parser: argparse.ArgumentParser):
        type=str_to_bool,
        help="Overrides base model context length",
    )
    model_group.add_argument(
        "--cache-size",
        type=int,
        help="The size of the prompt cache (in number of tokens) to allocate",
    )
    model_group.add_argument(
        "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
    )
    model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
    model_group.add_argument(
        "--prompt-template",
        type=str,
        help="Set the prompt template for chat completions",
    )
    model_group.add_argument(
        "--gpu-split-auto",
        type=str_to_bool,
        help="Automatically allocate resources to GPUs",
    )
    model_group.add_argument(
        "--autosplit-reserve",
        type=int,
        nargs="+",
        help="Reserve VRAM used for autosplit loading (in MBs)",
    )
    model_group.add_argument(
        "--gpu-split",
        type=float,
@@ -108,15 +125,44 @@ def add_model_args(parser: argparse.ArgumentParser):
        help="An integer array of GBs of vram to split between GPUs. "
        + "Ignored if gpu_split_auto is true",
    )
    model_group.add_argument(
        "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
    )
    model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
    model_group.add_argument(
        "--cache-mode",
        type=str,
        help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
    )
    model_group.add_argument(
        "--cache-size",
        type=int,
        help="The size of the prompt cache (in number of tokens) to allocate",
    )
    model_group.add_argument(
        "--chunk-size",
        type=int,
        help="Chunk size for prompt ingestion",
    )
    model_group.add_argument(
        "--max-batch-size",
        type=int,
        help="Maximum amount of prompts to process at one time",
    )
    model_group.add_argument(
        "--prompt-template",
        type=str,
        help="Set the jinja2 prompt template for chat completions",
    )
    model_group.add_argument(
        "--num-experts-per-token",
        type=int,
        help="Number of experts to use per token in MoE models",
    )
    model_group.add_argument(
        "--use-cfg",
        "--fasttensors",
        type=str_to_bool,
        help="Enables CFG support",
        help="Possibly increases model loading speeds",
    )


@@ -132,6 +178,11 @@ def add_logging_args(parser: argparse.ArgumentParser):
        type=str_to_bool,
        help="Enable generation parameter logging",
    )
    logging_group.add_argument(
        "--log-requests",
        type=str_to_bool,
        help="Enable request logging",
    )


def add_developer_args(parser: argparse.ArgumentParser):
@@ -149,5 +200,38 @@ def add_developer_args(parser: argparse.ArgumentParser):
    developer_group.add_argument(
        "--cuda-malloc-backend",
        type=str_to_bool,
        help="Disables API request streaming",
        help="Runs with the pytorch CUDA malloc backend",
    )
    developer_group.add_argument(
        "--uvloop",
        type=str_to_bool,
        help="Run asyncio using Uvloop or Winloop",
    )


def add_sampling_args(parser: argparse.ArgumentParser):
    """Adds sampling-specific arguments"""

    sampling_group = parser.add_argument_group("sampling")
    sampling_group.add_argument(
        "--override-preset", type=str, help="Select a sampler override preset"
    )


def add_embeddings_args(parser: argparse.ArgumentParser):
    """Adds arguments specific to embeddings"""

    embeddings_group = parser.add_argument_group("embeddings")
    embeddings_group.add_argument(
        "--embedding-model-dir",
        type=str,
        help="Overrides the directory to look for models",
    )
    embeddings_group.add_argument(
        "--embedding-model-name", type=str, help="An initial model to load"
    )
    embeddings_group.add_argument(
        "--embeddings-device",
        type=str,
        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
    )
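Several flags above use `type=str_to_bool` so boolean options can be passed as strings on the command line. A hedged sketch of what such a converter could look like (the real helper lives elsewhere in tabbyAPI and may differ):

```python
import argparse

def str_to_bool(value: str) -> bool:
    # Hypothetical converter: accept common truthy/falsy spellings
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")

parser = argparse.ArgumentParser()
parser.add_argument("--disable-auth", type=str_to_bool)
args = parser.parse_args(["--disable-auth", "true"])
print(args.disable_auth)  # True
```

Using a converter instead of `action="store_true"` keeps the three-state distinction between unset (`None`), `True`, and `False`, which matters when CLI args override config values.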
@@ -5,11 +5,13 @@ application, it should be fine.

import secrets
import yaml
from fastapi import Header, HTTPException
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from loguru import logger
from typing import Optional

from common.utils import coalesce


class AuthKeys(BaseModel):
    """
@@ -75,7 +77,27 @@ def load_auth_keys(disable_from_config: bool):
    )


async def validate_key_permission(test_key: str):
def get_key_permission(request: Request):
    """
    Gets the key permission from a request.

    Internal only! Use the depends functions for incoming requests.
    """

    # Give full admin permissions if auth is disabled
    if DISABLE_AUTH:
        return "admin"

    # Hyphens are okay here
    test_key = coalesce(
        request.headers.get("x-admin-key"),
        request.headers.get("x-api-key"),
        request.headers.get("authorization"),
    )

    if test_key is None:
        raise ValueError("The provided authentication key is missing.")

    if test_key.lower().startswith("bearer"):
        test_key = test_key.split(" ")[1]
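The header lookup in `get_key_permission` coalesces three possible auth headers and then strips an OAI-style bearer prefix. A standalone sketch of that flow, with the `coalesce` body assumed from its usage:

```python
def coalesce(*args):
    """Return the first non-None argument (assumed to mirror common.utils.coalesce)."""
    return next((a for a in args if a is not None), None)

def extract_key(headers: dict) -> str:
    test_key = coalesce(
        headers.get("x-admin-key"),
        headers.get("x-api-key"),
        headers.get("authorization"),
    )
    if test_key is None:
        raise ValueError("The provided authentication key is missing.")
    # Strip an OAI-style "Bearer " prefix if present
    if test_key.lower().startswith("bearer"):
        test_key = test_key.split(" ")[1]
    return test_key

print(extract_key({"authorization": "Bearer abc123"}))  # abc123
```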
@@ -46,15 +46,12 @@ def from_args(args: dict):
    GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}

    # Generation Logging config
    gen_logging_override = args.get("logging")
    if gen_logging_override:
        cur_gen_logging_config = gen_logging_config()
    logging_override = args.get("logging")
    if logging_override:
        cur_logging_config = logging_config()
        GLOBAL_CONFIG["logging"] = {
            **cur_gen_logging_config,
            **{
                k.replace("log_", ""): gen_logging_override[k]
                for k in gen_logging_override
            },
            **cur_logging_config,
            **{k.replace("log_", ""): logging_override[k] for k in logging_override},
        }

    developer_override = args.get("developer")
@@ -62,6 +59,11 @@ def from_args(args: dict):
        cur_developer_config = developer_config()
        GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}

    embeddings_override = args.get("embeddings")
    if embeddings_override:
        cur_embeddings_config = embeddings_config()
        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}


def sampling_config():
    """Returns the sampling parameter config from the global config"""
@@ -90,11 +92,16 @@ def network_config():
    return unwrap(GLOBAL_CONFIG.get("network"), {})


def gen_logging_config():
    """Returns the generation logging config from the global config"""
def logging_config():
    """Returns the logging config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("logging"), {})


def developer_config():
    """Returns the developer specific config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("developer"), {})


def embeddings_config():
    """Returns the embeddings config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
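The logging override merge above strips the CLI's `log_` prefix so argparse keys (`log_prompt`, `log_requests`) line up with the config-file keys. The dict-merge pattern in isolation:

```python
# Config file keys have no prefix; CLI overrides arrive as log_* keys
cur_logging_config = {"prompt": False, "generation_params": False}
logging_override = {"log_prompt": True}

# Later dict entries win, so the stripped overrides replace config values
merged = {
    **cur_logging_config,
    **{k.replace("log_", ""): logging_override[k] for k in logging_override},
}
print(merged)  # {'prompt': True, 'generation_params': False}
```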
@@ -101,6 +101,7 @@ async def hf_repo_download(
    chunk_limit: Optional[float],
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    timeout: Optional[int],
    repo_type: Optional[str] = "model",
):
    """Gets a repo's information from HuggingFace and downloads it locally."""
@@ -145,7 +146,8 @@ async def hf_repo_download(
    logger.info(f"Saving {repo_id} to {str(download_path)}")

    try:
        async with aiohttp.ClientSession() as session:
        client_timeout = aiohttp.ClientTimeout(total=timeout)  # Turn off timeout
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            tasks = []
            logger.info(f"Starting download for {repo_id}")
|
||||
logger.info(f"Generation options: {kwargs}\n")
|
||||
|
||||
|
||||
def log_prompt(prompt: str, negative_prompt: Optional[str]):
|
||||
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
|
||||
"""Logs the prompt to console."""
|
||||
if PREFERENCES.prompt:
|
||||
formatted_prompt = "\n" + prompt
|
||||
logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
|
||||
logger.info(
|
||||
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
|
||||
)
|
||||
|
||||
if negative_prompt:
|
||||
formatted_negative_prompt = "\n" + negative_prompt
|
||||
logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
|
||||
|
||||
|
||||
def log_response(response: str):
|
||||
def log_response(request_id: str, response: str):
|
||||
"""Logs the response to console."""
|
||||
if PREFERENCES.prompt:
|
||||
formatted_response = "\n" + response
|
||||
logger.info(f"Response: {formatted_response if response else 'Empty'}\n")
|
||||
logger.info(
|
||||
f"Response (ID: {request_id}): "
|
||||
f"{formatted_response if response else 'Empty'}\n"
|
||||
)
|
||||
|
||||
|
||||
def log_metrics(
|
||||
request_id: str,
|
||||
queue_time: float,
|
||||
prompt_tokens: int,
|
||||
cached_tokens: int,
|
||||
@@ -80,7 +86,7 @@ def log_metrics(
|
||||
max_seq_len: int,
|
||||
):
|
||||
initial_response = (
|
||||
f"Metrics: {generated_tokens} tokens generated in "
|
||||
f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
|
||||
f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
|
||||
)
|
||||
itemization = []
|
||||
|
||||
100
common/model.py
@@ -5,11 +5,14 @@ Containers exist as a common interface for backends.
"""

import pathlib
from enum import Enum
from fastapi import HTTPException
from loguru import logger
from typing import Optional

from common import config
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.utils import unwrap
from endpoints.utils import do_export_openapi

@@ -18,6 +21,21 @@ if not do_export_openapi:

# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None

# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb

if has_infinity_emb:
    from backends.infinity.model import InfinityContainer

    embeddings_container: Optional[InfinityContainer] = None


class ModelType(Enum):
    MODEL = "model"
    DRAFT = "draft"
    EMBEDDING = "embedding"


def load_progress(module, modules):
@@ -25,11 +43,11 @@ def load_progress(module, modules):
    yield module, modules


async def unload_model(skip_wait: bool = False):
async def unload_model(skip_wait: bool = False, shutdown: bool = False):
    """Unloads a model"""
    global container

    await container.unload(skip_wait=skip_wait)
    await container.unload(skip_wait=skip_wait, shutdown=shutdown)
    container = None


@@ -46,8 +64,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
            f'Model "{loaded_model_name}" is already loaded! Aborting.'
        )

    # Unload the existing model
    if container and container.model:
        logger.info("Unloading existing model.")
        await unload_model()

@@ -98,17 +114,89 @@ async def unload_loras():
    await container.unload(loras_only=True)


def get_config_default(key, fallback=None, is_draft=False):
async def load_embedding_model(model_path: pathlib.Path, **kwargs):
    global embeddings_container

    # Break out if infinity isn't installed
    if not has_infinity_emb:
        raise ImportError(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )

    # Check if the model is already loaded
    if embeddings_container and embeddings_container.engine:
        loaded_model_name = embeddings_container.model_dir.name

        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
            raise ValueError(
                f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
            )

        logger.info("Unloading existing embeddings model.")
        await unload_embedding_model()

    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)


async def unload_embedding_model():
    global embeddings_container

    await embeddings_container.unload()
    embeddings_container = None


def get_config_default(key: str, fallback=None, model_type: str = "model"):
    """Fetches a default value from model config if allowed by the user."""

    model_config = config.model_config()
    default_keys = unwrap(model_config.get("use_as_default"), [])

    # Add extra keys to defaults
    default_keys.append("embeddings_device")

    if key in default_keys:
        # Is this a draft model load parameter?
        if is_draft:
        if model_type == "draft":
            draft_config = config.draft_model_config()
            return unwrap(draft_config.get(key), fallback)
        elif model_type == "embedding":
            embeddings_config = config.embeddings_config()
            return unwrap(embeddings_config.get(key), fallback)
        else:
            return unwrap(model_config.get(key), fallback)
    else:
        return fallback


async def check_model_container():
    """FastAPI depends that checks if a model isn't loaded or currently loading."""

    if container is None or not (container.model_is_loading or container.model_loaded):
        error_message = handle_request_error(
            "No models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)


async def check_embeddings_container():
    """
    FastAPI depends that checks if an embeddings model is loaded.

    This is the same as the model container check, but with embeddings instead.
    """

    if embeddings_container is None or not (
        embeddings_container.model_is_loading or embeddings_container.model_loaded
    ):
        error_message = handle_request_error(
            "No embedding models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)
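The `get_config_default` fallback logic can be exercised in isolation; the config dictionaries below are hypothetical stand-ins for `config.model_config()` and `config.embeddings_config()`:

```python
def unwrap(value, default=None):
    # Assumed helper: first non-None value wins
    return value if value is not None else default

# Hypothetical config contents for illustration only
MODEL_CONFIG = {"use_as_default": ["max_seq_len"], "max_seq_len": 4096}
EMBEDDINGS_CONFIG = {"embeddings_device": "cpu"}

def get_config_default(key: str, fallback=None, model_type: str = "model"):
    default_keys = list(unwrap(MODEL_CONFIG.get("use_as_default"), []))
    default_keys.append("embeddings_device")  # always allowed as a default
    if key not in default_keys:
        return fallback
    if model_type == "embedding":
        return unwrap(EMBEDDINGS_CONFIG.get(key), fallback)
    return unwrap(MODEL_CONFIG.get(key), fallback)

print(get_config_default("max_seq_len", 2048))  # 4096 (whitelisted via use_as_default)
print(get_config_default("rope_alpha", 1.0))    # 1.0 (not whitelisted, fallback wins)
```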
@@ -1,12 +1,17 @@
"""Common utility functions"""

import asyncio
import json
import socket
import traceback
from fastapi import Request
from fastapi import Depends, HTTPException, Request
from loguru import logger
from pydantic import BaseModel
from typing import Optional
from uuid import uuid4

from common import config
from common.utils import unwrap


class TabbyRequestErrorMessage(BaseModel):
@@ -33,15 +38,18 @@ def get_generator_error(message: str, exc_info: bool = True):
def handle_request_error(message: str, exc_info: bool = True):
    """Log a request error to the console."""

    trace = traceback.format_exc()
    send_trace = unwrap(config.network_config().get("send_tracebacks"), False)

    error_message = TabbyRequestErrorMessage(
        message=message, trace=traceback.format_exc()
        message=message, trace=trace if send_trace else None
    )

    request_error = TabbyRequestError(error=error_message)

    # Log the error and provided message to the console
    if error_message.trace and exc_info:
        logger.error(error_message.trace)
    if trace and exc_info:
        logger.error(trace)

    logger.error(f"Sent to request: {message}")

@@ -78,8 +86,9 @@ async def run_with_request_disconnect(

    try:
        return call_task.result()
    except (asyncio.CancelledError, asyncio.InvalidStateError):
    except (asyncio.CancelledError, asyncio.InvalidStateError) as ex:
        handle_request_disconnect(disconnect_message)
        raise HTTPException(422, disconnect_message) from ex


def is_port_in_use(port: int) -> bool:
@@ -93,3 +102,39 @@ def is_port_in_use(port: int) -> bool:
    test_socket.settimeout(1)
    with test_socket:
        return test_socket.connect_ex(("localhost", port)) == 0


async def add_request_id(request: Request):
    """FastAPI depends to add a UUID to a request's state."""

    request.state.id = uuid4().hex
    return request


async def log_request(request: Request):
    """FastAPI depends to log a request to the user."""

    log_message = [f"Information for {request.method} request {request.state.id}:"]

    log_message.append(f"URL: {request.url}")
    log_message.append(f"Headers: {dict(request.headers)}")

    if request.method != "GET":
        body_bytes = await request.body()
        if body_bytes:
            body = json.loads(body_bytes.decode("utf-8"))

            log_message.append(f"Body: {dict(body)}")

    logger.info("\n".join(log_message))


def get_global_depends():
    """Returns global dependencies for a FastAPI app."""

    depends = [Depends(add_request_id)]

    if config.logging_config().get("requests"):
        depends.append(Depends(log_request))

    return depends
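The `add_request_id`/`log_request` pair tags each request with a UUID hex string and assembles a multi-line log entry. The core of that, sketched without FastAPI (the URL and body below are illustrative values, not taken from the diff):

```python
import json
from uuid import uuid4

request_id = uuid4().hex  # what add_request_id stores on request.state.id

log_message = [f"Information for POST request {request_id}:"]
log_message.append("URL: http://localhost:5000/v1/completions")

# Non-GET requests also get their JSON body logged
body = json.loads('{"prompt": "hi", "max_tokens": 8}')
log_message.append(f"Body: {dict(body)}")

print("\n".join(log_message))
```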
@@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel):

    max_tokens: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        validation_alias=AliasChoices("max_tokens", "max_length"),
        description="Aliases: max_length",
        examples=[150],
    )

    min_tokens: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
        validation_alias=AliasChoices("min_tokens", "min_length"),
        description="Aliases: min_length",
        examples=[0],
    )

@@ -30,13 +34,22 @@ class BaseSamplerRequest(BaseModel):
    )

    stop: Optional[Union[str, List[str]]] = Field(
        default_factory=lambda: get_default_sampler_value("stop", [])
        default_factory=lambda: get_default_sampler_value("stop", []),
        validation_alias=AliasChoices("stop", "stop_sequence"),
        description="Aliases: stop_sequence",
    )

    banned_strings: Optional[Union[str, List[str]]] = Field(
        default_factory=lambda: get_default_sampler_value("banned_strings", [])
    )

    banned_tokens: Optional[Union[List[int], str]] = Field(
        default_factory=lambda: get_default_sampler_value("banned_tokens", []),
        validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
        description="Aliases: custom_token_bans",
        examples=[[128, 330]],
    )

    token_healing: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("token_healing", False)
    )
@@ -76,6 +89,13 @@ class BaseSamplerRequest(BaseModel):
        examples=[1.0],
    )

    typical: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("typical", 1.0),
        validation_alias=AliasChoices("typical", "typical_p"),
        description="Aliases: typical_p",
        examples=[1.0],
    )

    skew: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("skew", 0.0),
        examples=[0.0],
@@ -91,9 +111,24 @@ class BaseSamplerRequest(BaseModel):

    repetition_penalty: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
        validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
        description="Aliases: rep_pen",
        examples=[1.0],
    )

    penalty_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
        validation_alias=AliasChoices(
            "penalty_range",
            "repetition_range",
            "repetition_penalty_range",
            "rep_pen_range",
        ),
        description=(
            "Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
        ),
    )

    repetition_decay: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
    )
@@ -118,6 +153,8 @@ class BaseSamplerRequest(BaseModel):

    ban_eos_token: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
        validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
        description="Aliases: ignore_eos",
        examples=[False],
    )

@@ -151,24 +188,6 @@ class BaseSamplerRequest(BaseModel):
        default_factory=lambda: get_default_sampler_value("speculative_ngram"),
    )

    # Aliased variables
    typical: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("typical", 1.0),
        validation_alias=AliasChoices("typical", "typical_p"),
        description="Aliases: typical_p",
        examples=[1.0],
    )

    penalty_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
        validation_alias=AliasChoices(
            "penalty_range",
            "repetition_range",
            "repetition_penalty_range",
        ),
        description="Aliases: repetition_range, repetition_penalty_range",
    )

    cfg_scale: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
        validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
@@ -196,13 +215,6 @@ class BaseSamplerRequest(BaseModel):
        examples=[1.0],
    )

    banned_tokens: Optional[Union[List[int], str]] = Field(
        default_factory=lambda: get_default_sampler_value("banned_tokens", []),
        validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
        description="Aliases: custom_token_bans",
        examples=[[128, 330]],
    )

    # TODO: Return back to adaptable class-based validation. But that's just too much
    # abstraction compared to simple if statements at the moment
    def validate_params(self):
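The `AliasChoices` fields above let Kobold-style payload keys (`rep_pen`, `typical_p`, `ignore_eos`) map onto the canonical sampler names. A plain-Python sketch of first-alias-wins resolution (a deliberate simplification of what pydantic's validation aliases do, with a hypothetical helper name):

```python
def resolve_alias(payload: dict, *aliases, default=None):
    """Hypothetical helper: return the value of the first alias present."""
    for name in aliases:
        if name in payload:
            return payload[name]
    return default

kobold_style = {"rep_pen": 1.1}
print(resolve_alias(kobold_style, "repetition_penalty", "rep_pen", default=1.0))  # 1.1
```

Listing the canonical name first means clients sending both spellings get the canonical one.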
@@ -1,13 +1,40 @@
import asyncio
import signal
import sys
from loguru import logger
from types import FrameType

from common import model


SHUTTING_DOWN: bool = False


def signal_handler(*_):
    """Signal handler for main function. Run before uvicorn starts."""

    global SHUTTING_DOWN

    if SHUTTING_DOWN:
        return

    logger.warning("Shutdown signal called. Exiting gracefully.")
    SHUTTING_DOWN = True

    # Run async unloads for model
    asyncio.ensure_future(signal_handler_async())


async def signal_handler_async(*_):
    """Internal signal handler. Runs all async code to shut down the program."""

    if model.container:
        await model.unload_model(skip_wait=True, shutdown=True)

    if model.embeddings_container:
        await model.unload_embedding_model()

    # Exit the program
    sys.exit(0)
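The `SHUTTING_DOWN` guard makes the handler idempotent, so repeated signals during teardown don't re-schedule the async unloads. The guard pattern in isolation, with the scheduling call stubbed out:

```python
SHUTTING_DOWN = False
scheduled = []

def signal_handler(*_):
    global SHUTTING_DOWN
    if SHUTTING_DOWN:
        return  # a second Ctrl+C during shutdown is ignored
    SHUTTING_DOWN = True
    scheduled.append("async shutdown")  # stands in for asyncio.ensure_future(...)

signal_handler()
signal_handler()  # second signal is a no-op
print(scheduled)  # ['async shutdown']
```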
@@ -51,7 +51,7 @@ class PromptTemplate:
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            f"or greater. Current version: {package_version('jinja2')}\n"
            "Please upgrade jinja by running the following command: "
            "Please update jinja by running the following command: "
            "pip install --upgrade jinja2"
        )
@@ -1,6 +1,7 @@
import json
import pathlib
from typing import List, Optional, Union
from loguru import logger
from pydantic import BaseModel


@@ -11,6 +12,7 @@ class GenerationConfig(BaseModel):
    """

    eos_token_id: Optional[Union[int, List[int]]] = None
    bad_words_ids: Optional[List[List[int]]] = None

    @classmethod
    def from_file(self, model_directory: pathlib.Path):
@@ -30,3 +32,38 @@ class GenerationConfig(BaseModel):
            return [self.eos_token_id]
        else:
            return self.eos_token_id


class HuggingFaceConfig(BaseModel):
    """
    An abridged version of HuggingFace's model config.
    Will be expanded as needed.
    """

    badwordsids: Optional[str] = None

    @classmethod
    def from_file(self, model_directory: pathlib.Path):
        """Create an instance from a generation config file."""

        hf_config_path = model_directory / "config.json"
        with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
            hf_config_dict = json.load(hf_config_json)
            return self.model_validate(hf_config_dict)

    def get_badwordsids(self):
        """Wrapper method to fetch badwordsids."""

        if self.badwordsids:
            try:
                bad_words_list = json.loads(self.badwordsids)
                return bad_words_list
            except json.JSONDecodeError:
                logger.warning(
                    "Skipping badwordsids from config.json "
                    "since it's not a valid array."
                )

                return []
        else:
            return []
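`get_badwordsids` has to tolerate models whose `config.json` stores `badwordsids` as a string that may not be valid JSON. The parse-or-empty fallback on its own:

```python
import json

def parse_badwordsids(raw):
    """Parse a badwordsids string from config.json; fall back to [] on bad input."""
    if not raw:
        return []
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Mirrors the logger.warning branch: skip invalid arrays silently here
        return []

print(parse_badwordsids("[[128], [330]]"))  # [[128], [330]]
print(parse_badwordsids("not json"))        # []
```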
@@ -18,3 +18,9 @@ def prune_dict(input_dict):
    """Trim out instances of None from a dictionary."""

    return {k: v for k, v in input_dict.items() if v is not None}


def flat_map(input_list):
    """Flattens a list of lists into a single list."""

    return [item for sublist in input_list for item in sublist]
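Quick usage of the two helpers above, copied verbatim from the hunk:

```python
def prune_dict(input_dict):
    """Trim out instances of None from a dictionary."""
    return {k: v for k, v in input_dict.items() if v is not None}

def flat_map(input_list):
    """Flattens a list of lists into a single list."""
    return [item for sublist in input_list for item in sublist]

print(prune_dict({"a": 1, "b": None}))  # {'a': 1}
print(flat_map([[1, 2], [3]]))          # [1, 2, 3]
```

Note `flat_map` only flattens one level; nested sublists are left intact.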
@@ -19,6 +19,14 @@ network:
  # Turn on this option if you are ONLY connecting from localhost
  disable_auth: False

  # Send tracebacks over the API to clients (default: False)
  # NOTE: Only enable this for debug purposes
  send_tracebacks: False

  # Select API servers to enable (default: ["OAI"])
  # Possible values: OAI
  api_servers: ["OAI"]

# Options for logging
logging:
  # Enable prompt logging (default: False)
@@ -27,6 +35,10 @@ logging:
  # Enable generation parameter logging (default: False)
  generation_params: False

  # Enable request logging (default: False)
  # NOTE: Only use this for debugging!
  requests: False

# Options for sampling
sampling:
  # Override preset name. Find this in the sampler-overrides folder (default: None)
@@ -50,6 +62,16 @@ developer:
  # This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk.
  #cuda_malloc_backend: False

  # Enable Uvloop or Winloop (default: False)
  # Make the program utilize a faster async event loop which can improve performance
  # NOTE: It's recommended to enable this, but if something breaks, turn this off.
  #uvloop: False

  # Set process to use a higher priority
  # For realtime process priority, run as administrator or sudo
  # Otherwise, the priority will be set to high
  #realtime_process_priority: False

# Options for model overrides and loading
# Please read the comments to understand how arguments are handled between initial and API loads
model:
@@ -124,11 +146,11 @@ model:
  # NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096
  #chunk_size: 2048

  # Set the maximum amount of prompts to process at one time (batch)
  # This will be automatically adjusted depending on the cache size.
  # Set the maximum amount of prompts to process at one time (default: None/Automatic)
  # This will be automatically calculated if left blank.
  # A max batch size of 1 processes prompts one at a time.
  # NOTE: Only available for Nvidia ampere (30 series) and above GPUs
  #max_batch_size: 20
  #max_batch_size:

  # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
  # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
@@ -179,3 +201,22 @@ model:
  #loras:
  #- name: lora1
  #  scaling: 1.0

# Options for embedding models and loading.
# NOTE: Embeddings requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
  # Overrides directory to look for embedding models (default: models)
  embedding_model_dir: models

  # Device to load embedding models on (default: cpu)
  # Possible values: cpu, auto, cuda
  # NOTE: It's recommended to load embedding models on the CPU.
  # If you'd like to load on an AMD gpu, set this value to "cuda" as well.
  embeddings_device: cpu

  # The below parameters only apply for initial loads
  # All API based loads do NOT inherit these settings unless specified in use_as_default

  # An initial embedding model to load on the infinity backend (default: None)
  embedding_model_name:
@@ -12,7 +12,7 @@ services:
      - NAME=TabbyAPI
      - NVIDIA_VISIBLE_DEVICES=all
    volumes:
      - ./models:/usr/src/app/models
      - ./models:/app/models
    deploy:
      resources:
        reservations:
161 endpoints/Kobold/router.py Normal file
@@ -0,0 +1,161 @@
from sys import maxsize
from fastapi import APIRouter, Depends, Request
from sse_starlette import EventSourceResponse

from common import model
from common.auth import check_api_key
from common.model import check_model_container
from common.utils import unwrap
from endpoints.core.utils.model import get_current_model
from endpoints.Kobold.types.generation import (
    AbortRequest,
    AbortResponse,
    CheckGenerateRequest,
    GenerateRequest,
    GenerateResponse,
)
from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
from endpoints.Kobold.utils.generation import (
    abort_generation,
    generation_status,
    get_generation,
    stream_generation,
)


api_name = "KoboldAI"
router = APIRouter(prefix="/api")
urls = {
    "Generation": "http://{host}:{port}/api/v1/generate",
    "Streaming": "http://{host}:{port}/api/extra/generate/stream",
}

kai_router = APIRouter()
extra_kai_router = APIRouter()


def setup():
    router.include_router(kai_router, prefix="/v1")
    router.include_router(kai_router, prefix="/latest", include_in_schema=False)
    router.include_router(extra_kai_router, prefix="/extra")

    return router


@kai_router.post(
    "/generate",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
    response = await get_generation(data, request)

    return response


@extra_kai_router.post(
    "/generate/stream",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
    response = EventSourceResponse(stream_generation(data, request), ping=maxsize)

    return response


@extra_kai_router.post(
    "/abort",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest) -> AbortResponse:
    response = await abort_generation(data.genkey)

    return response


@extra_kai_router.get(
    "/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.post(
    "/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
    response = await generation_status(data.genkey)

    return response


@kai_router.get(
    "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model() -> CurrentModelResponse:
    """Fetches the current model and who owns it."""

    current_model_card = get_current_model()
    return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}


@extra_kai_router.post(
    "/tokencount",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
    raw_tokens = model.container.encode_tokens(data.prompt)
    tokens = unwrap(raw_tokens, [])
    return TokenCountResponse(value=len(tokens), ids=tokens)


@kai_router.get(
    "/config/max_length",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@kai_router.get(
    "/config/max_context_length",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.get(
    "/true_max_context_length",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_max_length() -> MaxLengthResponse:
    """Fetches the max length of the model."""

    max_length = model.container.get_model_parameters().get("max_seq_len")
    return {"value": max_length}


@kai_router.get("/info/version")
async def get_version():
    """Impersonate KAI United."""

    return {"result": "1.2.5"}


@extra_kai_router.get("/version")
async def get_extra_version():
    """Impersonate Koboldcpp."""

    return {"result": "KoboldCpp", "version": "1.61"}


@kai_router.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Used for KAI compliance."""

    return {"values": []}


@kai_router.get("/config/soft_prompt")
async def get_current_softprompt():
    """Used for KAI compliance."""

    return {"value": ""}


@kai_router.put("/config/soft_prompt")
async def set_current_softprompt():
    """Used for KAI compliance."""

    return {}
60 endpoints/Kobold/types/generation.py Normal file
@@ -0,0 +1,60 @@
from typing import List, Optional

from pydantic import BaseModel, Field
from common import model
from common.sampling import BaseSamplerRequest, get_default_sampler_value
from common.utils import flat_map, unwrap


class GenerateRequest(BaseSamplerRequest):
    prompt: str
    genkey: Optional[str] = None
    use_default_badwordsids: Optional[bool] = False
    dynatemp_range: Optional[float] = Field(
        default_factory=get_default_sampler_value("dynatemp_range")
    )

    def to_gen_params(self, **kwargs):
        # Exl2 uses -1 to include all tokens in repetition penalty
        if self.penalty_range == 0:
            self.penalty_range = -1

        if self.dynatemp_range:
            self.min_temp = self.temperature - self.dynatemp_range
            self.max_temp = self.temperature + self.dynatemp_range

        # Move badwordsids into banned tokens for generation
        if self.use_default_badwordsids:
            bad_words_ids = unwrap(
                model.container.generation_config.bad_words_ids,
                model.container.hf_config.get_badwordsids(),
            )

            if bad_words_ids:
                self.banned_tokens += flat_map(bad_words_ids)

        return super().to_gen_params(**kwargs)


class GenerateResponseResult(BaseModel):
    text: str


class GenerateResponse(BaseModel):
    results: List[GenerateResponseResult] = Field(default_factory=list)


class StreamGenerateChunk(BaseModel):
    token: str


class AbortRequest(BaseModel):
    genkey: str


class AbortResponse(BaseModel):
    success: bool


class CheckGenerateRequest(BaseModel):
    genkey: str
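The dynatemp handling above can be read in isolation: KoboldAI clients send a single symmetric `dynatemp_range`, while the sampler expects explicit minimum and maximum temperatures. A minimal standalone sketch of that conversion (hypothetical helper name, not part of TabbyAPI):

```python
def dynatemp_bounds(temperature: float, dynatemp_range: float) -> tuple[float, float]:
    """Convert a symmetric dynatemp range into (min_temp, max_temp),
    mirroring the assignment in GenerateRequest.to_gen_params."""
    return (temperature - dynatemp_range, temperature + dynatemp_range)
```

For example, a temperature of 1.0 with a range of 0.5 samples between 0.5 and 1.5.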
9 endpoints/Kobold/types/model.py Normal file
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class CurrentModelResponse(BaseModel):
    result: str


class MaxLengthResponse(BaseModel):
    value: int
15 endpoints/Kobold/types/token.py Normal file
@@ -0,0 +1,15 @@
from pydantic import BaseModel
from typing import List


class TokenCountRequest(BaseModel):
    """Represents a KAI tokenization request."""

    prompt: str


class TokenCountResponse(BaseModel):
    """Represents a KAI tokenization response."""

    value: int
    ids: List[int]
151 endpoints/Kobold/utils/generation.py Normal file
@@ -0,0 +1,151 @@
import asyncio
from asyncio import CancelledError
from fastapi import HTTPException, Request
from loguru import logger
from sse_starlette import ServerSentEvent

from common import model
from common.networking import (
    get_generator_error,
    handle_request_disconnect,
    handle_request_error,
    request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
    AbortResponse,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseResult,
    StreamGenerateChunk,
)


generation_cache = {}


async def override_request_id(request: Request, data: GenerateRequest):
    """Overrides the request ID with a KAI genkey if present."""

    if data.genkey:
        request.state.id = data.genkey


def _create_response(text: str):
    results = [GenerateResponseResult(text=text)]
    return GenerateResponse(results=results)


def _create_stream_chunk(text: str):
    return StreamGenerateChunk(token=text)


async def _stream_collector(data: GenerateRequest, request: Request):
    """Common async generator for generation streams."""

    abort_event = asyncio.Event()
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    # Create a new entry in the cache
    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}

    try:
        logger.info(f"Received Kobold generation request {data.genkey}")

        generator = model.container.generate_gen(
            data.prompt, data.genkey, abort_event, **data.to_gen_params()
        )
        async for generation in generator:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect(
                    f"Kobold generation {data.genkey} cancelled by user."
                )

            text = generation.get("text")

            # Update the generation cache with the new chunk
            if text:
                generation_cache[data.genkey]["text"] += text
                yield text

            if "finish_reason" in generation:
                logger.info(f"Finished streaming Kobold request {data.genkey}")
                break
    except CancelledError:
        # If the request disconnects, break out
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Kobold generation {data.genkey} cancelled by user."
            )
    finally:
        # Cleanup the cache
        del generation_cache[data.genkey]


async def stream_generation(data: GenerateRequest, request: Request):
    """Wrapper for stream generations."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        async for chunk in _stream_collector(data, request):
            response = _create_stream_chunk(chunk)
            yield ServerSentEvent(
                event="message", data=response.model_dump_json(), sep="\n"
            )
    except Exception:
        yield get_generator_error(
            f"Kobold generation {data.genkey} aborted. "
            "Please check the server console."
        )


async def get_generation(data: GenerateRequest, request: Request):
    """Wrapper to get a static generation."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        full_text = ""
        async for chunk in _stream_collector(data, request):
            full_text += chunk

        response = _create_response(full_text)
        return response
    except Exception as exc:
        error_message = handle_request_error(
            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc


async def abort_generation(genkey: str):
    """Aborts a generation from the cache."""

    abort_event = unwrap(generation_cache.get(genkey), {}).get("abort")
    if abort_event:
        abort_event.set()
        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")

    return AbortResponse(success=True)


async def generation_status(genkey: str):
    """Fetches the status of a generation from the cache."""

    current_text = unwrap(generation_cache.get(genkey), {}).get("text")
    if current_text:
        response = _create_response(current_text)
    else:
        response = GenerateResponse()

    return response
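The abort mechanism in this module comes down to a shared cache of `asyncio.Event` objects keyed by genkey: the producer checks its event on every iteration, an abort handler sets it, and a `finally` block guarantees the cache entry is removed. A minimal self-contained sketch of that pattern (hypothetical names, no FastAPI or model container involved):

```python
import asyncio

# Shared cache of abort events keyed by generation key,
# analogous to generation_cache above.
cache: dict[str, asyncio.Event] = {}


async def producer(genkey: str, results: list) -> None:
    cache[genkey] = asyncio.Event()
    try:
        for i in range(100):
            if cache[genkey].is_set():  # abort requested, stop producing
                break
            results.append(i)
            await asyncio.sleep(0)  # yield control, like streaming a token
    finally:
        del cache[genkey]  # always clean up the cache entry


async def abort(genkey: str) -> None:
    event = cache.get(genkey)
    if event:
        event.set()


async def main() -> list:
    results: list = []
    task = asyncio.create_task(producer("g1", results))
    await asyncio.sleep(0)  # let the producer start and register itself
    await abort("g1")
    await task
    return results


out = asyncio.run(main())  # producer stops after its first chunk
```

The `finally` cleanup mirrors `_stream_collector`: whether the stream finishes, aborts, or raises, the cache never leaks an entry.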
@@ -1,48 +1,21 @@
import asyncio
import pathlib
from loguru import logger
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional

from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.downloader import hf_repo_download
from common import config, model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
    LoraCard,
    LoraList,
    LoraLoadRequest,
    LoraLoadResponse,
)
from endpoints.OAI.types.model import (
    ModelCard,
    ModelList,
    ModelLoadRequest,
    ModelCardParameters,
    ModelLoadResponse,
)
from endpoints.OAI.types.sampler_overrides import (
    SamplerOverrideListResponse,
    SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
    TokenEncodeRequest,
    TokenEncodeResponse,
    TokenDecodeRequest,
    TokenDecodeResponse,
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
    format_prompt_with_template,
    generate_chat_completion,
@@ -52,25 +25,19 @@ from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list
from endpoints.OAI.utils.embeddings import get_embeddings


api_name = "OAI"
router = APIRouter()
urls = {
    "Completions": "http://{host}:{port}/v1/completions",
    "Chat completions": "http://{host}:{port}/v1/chat/completions",
}


async def check_model_container():
    """FastAPI depends that checks if a model isn't loaded or currently loading."""

    if model.container is None or not (
        model.container.model_is_loading or model.container.model_loaded
    ):
        error_message = handle_request_error(
            "No models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)
def setup():
    return router


# Completions endpoint
@@ -106,12 +73,14 @@ async def completion_request(
            ping=maxsize,
        )
    else:
        generate_task = asyncio.create_task(generate_completion(data, model_path))
        generate_task = asyncio.create_task(
            generate_completion(data, request, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message="Completion generation cancelled by user.",
            disconnect_message=f"Completion {request.state.id} cancelled by user.",
        )
        return response

@@ -206,392 +175,28 @@ async def chat_completion_request(
        )
    else:
        generate_task = asyncio.create_task(
            generate_chat_completion(prompt, data, model_path)
            generate_chat_completion(prompt, data, request, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message="Chat completion generation cancelled by user.",
            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
        )
        return response


# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models() -> ModelList:
    """Lists all models in the model directory."""
    model_config = config.model_config()
    model_dir = unwrap(model_config.get("model_dir"), "models")
    model_path = pathlib.Path(model_dir)

    draft_model_dir = config.draft_model_config().get("draft_model_dir")

    models = get_model_list(model_path.resolve(), draft_model_dir)
    if unwrap(model_config.get("use_dummy_models"), False):
        models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))

    return models


# Currently loaded model endpoint
@router.get(
    "/v1/model",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_current_model() -> ModelCard:
    """Returns the currently loaded model."""
    model_params = model.container.get_model_parameters()
    draft_model_params = model_params.pop("draft", {})

    if draft_model_params:
        model_params["draft"] = ModelCard(
            id=unwrap(draft_model_params.get("name"), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )
    else:
        draft_model_params = None

    model_card = ModelCard(
        id=unwrap(model_params.pop("name", None), "unknown"),
        parameters=ModelCardParameters.model_validate(model_params),
        logging=gen_logging.PREFERENCES,
    )

    if draft_model_params:
        draft_card = ModelCard(
            id=unwrap(draft_model_params.pop("name", None), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )

        model_card.parameters.draft = draft_card

    return model_card


@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models() -> ModelList:
    """Lists all draft models in the model directory."""
    draft_model_dir = unwrap(
        config.draft_model_config().get("draft_model_dir"), "models"
    )
    draft_model_path = pathlib.Path(draft_model_dir)

    models = get_model_list(draft_model_path.resolve())

    return models


# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
    """Loads a model into the model container. This returns an SSE stream."""

    # Verify request parameters
    if not data.name:
        error_message = handle_request_error(
            "A model name was not provided for load.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
    model_path = model_path / data.name

    draft_model_path = None
    if data.draft:
        if not data.draft.draft_model_name:
            error_message = handle_request_error(
                "Could not find the draft model name for model load.",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message)

        draft_model_path = unwrap(
            config.draft_model_config().get("draft_model_dir"), "models"
        )

    if not model_path.exists():
        error_message = handle_request_error(
            "Could not find the model path for load. Check model name or config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    return EventSourceResponse(
        stream_model_load(data, model_path, draft_model_path), ping=maxsize
    )


# Unload model endpoint
# Embeddings endpoint
@router.post(
    "/v1/model/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def unload_model():
    """Unloads the currently loaded model."""
    await model.unload_model(skip_wait=True)


@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
    """Downloads a model from HuggingFace."""

    try:
        download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))

        # For now, the downloader and request data are 1:1
        download_path = await run_with_request_disconnect(
            request,
            download_task,
            "Download request cancelled by user. Files have been cleaned up.",
        )

        return DownloadResponse(download_path=str(download_path))
    except Exception as exc:
        error_message = handle_request_error(str(exc)).error.message

        raise HTTPException(400, error_message) from exc


# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras() -> LoraList:
    """Lists all LoRAs in the lora directory."""
    lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
    loras = get_lora_list(lora_path.resolve())

    return loras


# Currently loaded loras endpoint
@router.get(
    "/v1/lora",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_active_loras() -> LoraList:
    """Returns the currently loaded loras."""
    active_loras = LoraList(
        data=[
            LoraCard(
                id=pathlib.Path(lora.lora_path).parent.name,
                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
            )
            for lora in model.container.get_loras()
        ]
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    embeddings_task = asyncio.create_task(get_embeddings(data, request))
    response = await run_with_request_disconnect(
        request,
        embeddings_task,
        f"Embeddings request {request.state.id} cancelled by user.",
    )

    return active_loras


# Load lora endpoint
@router.post(
    "/v1/lora/load",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
    """Loads a LoRA into the model container."""

    if not data.loras:
        error_message = handle_request_error(
            "List of loras to load is not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
    if not lora_dir.exists():
        error_message = handle_request_error(
            "A parent lora directory does not exist for load. Check your config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    load_result = await model.load_loras(
        lora_dir, **data.model_dump(), skip_wait=data.skip_queue
    )

    return LoraLoadResponse(
        success=unwrap(load_result.get("success"), []),
        failure=unwrap(load_result.get("failure"), []),
    )


# Unload lora endpoint
@router.post(
    "/v1/lora/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
    """Unloads the currently loaded loras."""

    await model.unload_loras()


# Encode tokens endpoint
@router.post(
    "/v1/token/encode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
    """Encodes a string or chat completion messages into tokens."""

    if isinstance(data.text, str):
        text = data.text
    else:
        special_tokens_dict = model.container.get_special_tokens(
            unwrap(data.add_bos_token, True)
        )

        template_vars = {
            "messages": data.text,
            "add_generation_prompt": False,
            **special_tokens_dict,
        }

        text, _ = model.container.prompt_template.render(template_vars)

    raw_tokens = model.container.encode_tokens(text, **data.get_params())
    tokens = unwrap(raw_tokens, [])
    response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

    return response


# Decode tokens endpoint
@router.post(
    "/v1/token/decode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
    """Decodes tokens into a string."""
    message = model.container.decode_tokens(data.tokens, **data.get_params())
    response = TokenDecodeResponse(text=unwrap(message, ""))

    return response


@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
    x_admin_key: Optional[str] = Header(None),
    x_api_key: Optional[str] = Header(None),
    authorization: Optional[str] = Header(None),
) -> AuthPermissionResponse:
    """
    Gets the access level/permission of a provided key in headers.

    Priority:
    - X-api-key
    - X-admin-key
    - Authorization
    """

    test_key = coalesce(x_admin_key, x_api_key, authorization)

    try:
        permission = await validate_key_permission(test_key)
        return AuthPermissionResponse(permission=permission)
    except ValueError as exc:
        error_message = handle_request_error(str(exc)).error.message

        raise HTTPException(400, error_message) from exc


@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates() -> TemplateList:
    templates = get_all_templates()
    template_strings = [template.stem for template in templates]
    return TemplateList(data=template_strings)


@router.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Switch the currently loaded template"""
    if not data.name:
        error_message = handle_request_error(
            "New template name not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        model.container.prompt_template = PromptTemplate.from_file(data.name)
    except FileNotFoundError as e:
        error_message = handle_request_error(
            f"The template name {data.name} doesn't exist. Check the spelling?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message) from e


@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Unloads the currently selected template"""

    model.container.prompt_template = None


# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides() -> SamplerOverrideListResponse:
    """API wrapper to list all currently applied sampler overrides"""

    return SamplerOverrideListResponse(
        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
    )


@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """Switch the currently loaded override preset"""

    if data.preset:
        try:
            sampling.overrides_from_file(data.preset)
        except FileNotFoundError as e:
            error_message = handle_request_error(
                f"Sampler override preset with name {data.preset} does not exist. "
                + "Check the spelling?",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message) from e
    elif data.overrides:
        sampling.overrides_from_dict(data.overrides)
    else:
        error_message = handle_request_error(
            "A sampler override preset or dictionary wasn't provided.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)


@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Unloads the currently selected override preset"""

    sampling.overrides_from_dict({})
@@ -65,3 +65,4 @@ class ChatCompletionStreamChunk(BaseModel):
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"
    usage: Optional[UsageStats] = None

@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
    type: str = "text"


class ChatCompletionStreamOptions(BaseModel):
    include_usage: Optional[bool] = False


class CommonCompletionRequest(BaseSamplerRequest):
    """Represents a common completion request."""

@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):

    # Generation info (remainder is in BaseSamplerRequest superclass)
    stream: Optional[bool] = False
    stream_options: Optional[ChatCompletionStreamOptions] = None
    logprobs: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("logprobs", 0)
    )
42	endpoints/OAI/types/embedding.py	Normal file
@@ -0,0 +1,42 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for."
    )
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.",
    )
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.",
    )


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats."
    )
    index: int = Field(
        ..., description="Index of the input text corresponding to the embedding."
    )


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
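The pydantic models above serialize to the OpenAI-style `/v1/embeddings` response. A minimal sketch of that JSON shape, using a plain dict (the model name is a hypothetical placeholder, not from the diff):

```python
import json

# Shape of the response the EmbeddingsResponse/EmbeddingObject/UsageInfo
# models above serialize to (model name is a made-up example)
response = {
    "object": "list",
    "data": [
        {"object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 0},
    ],
    "model": "example-embedding-model",
    "usage": {"prompt_tokens": 3, "total_tokens": 3, "completion_tokens": 0},
}

# Round-trip through JSON, as a client would receive it
payload = json.loads(json.dumps(response))
```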
@@ -5,7 +5,6 @@ import pathlib
from asyncio import CancelledError
from copy import deepcopy
from typing import List, Optional
from uuid import uuid4

from fastapi import HTTPException, Request
from jinja2 import TemplateError
@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
    ChatCompletionStreamChoice,
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector


def _create_response(generations: List[dict], model_name: Optional[str]):
def _create_response(
    request_id: str, generations: List[dict], model_name: Optional[str]
):
    """Create a chat completion response from the provided text."""

    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
        choices.append(choice)

    response = ChatCompletionResponse(
        id=f"chatcmpl-{request_id}",
        choices=choices,
        model=unwrap(model_name, ""),
        usage=UsageStats(
@@ -90,25 +93,40 @@ def _create_response(generations: List[dict], model_name: Optional[str]):


def _create_stream_chunk(
    const_id: str,
    request_id: str,
    generation: Optional[dict] = None,
    model_name: Optional[str] = None,
    is_usage_chunk: bool = False,
):
    """Create a chat completion stream chunk from the provided text."""

    index = generation.get("index")
    logprob_response = None
    choices = []
    usage_stats = None

    if "finish_reason" in generation:
    if is_usage_chunk:
        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
        completion_tokens = unwrap(generation.get("generated_tokens"), 0)

        usage_stats = UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
    elif "finish_reason" in generation:
        choice = ChatCompletionStreamChoice(
            index=index,
            finish_reason=generation.get("finish_reason"),
        )

        choices.append(choice)
    else:
        message = ChatCompletionMessage(
            role="assistant", content=unwrap(generation.get("text"), "")
        )

        logprob_response = None

        token_probs = unwrap(generation.get("token_probs"), {})
        if token_probs:
            logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +150,13 @@ def _create_stream_chunk(
            logprobs=logprob_response,
        )

        choices.append(choice)

    chunk = ChatCompletionStreamChunk(
        id=const_id, choices=[choice], model=unwrap(model_name, "")
        id=f"chatcmpl-{request_id}",
        choices=choices,
        model=unwrap(model_name, ""),
        usage=usage_stats,
    )

    return chunk
@@ -215,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
        raise HTTPException(400, error_message) from exc


async def _stream_collector(
    task_idx: int,
    gen_queue: asyncio.Queue,
    prompt: str,
    abort_event: asyncio.Event,
    **kwargs,
):
    """Collects a stream and places results in a common queue"""

    try:
        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
        async for generation in new_generation:
            generation["index"] = task_idx

            await gen_queue.put(generation)

            if "finish_reason" in generation:
                break
    except Exception as e:
        await gen_queue.put(e)


async def stream_generate_chat_completion(
    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
    """Generator for the generation process."""
    const_id = f"chatcmpl-{uuid4().hex}"
    abort_event = asyncio.Event()
    gen_queue = asyncio.Queue()
    gen_tasks: List[asyncio.Task] = []
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    try:
        logger.info(f"Received chat completion streaming request {request.state.id}")

        gen_params = data.to_gen_params()

        for n in range(0, data.n):
@@ -257,7 +259,14 @@ async def stream_generate_chat_completion(
            task_gen_params = gen_params

            gen_task = asyncio.create_task(
                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
                _stream_collector(
                    n,
                    gen_queue,
                    prompt,
                    request.state.id,
                    abort_event,
                    **task_gen_params,
                )
            )

            gen_tasks.append(gen_task)
@@ -266,7 +275,9 @@ async def stream_generate_chat_completion(
        while True:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect("Completion generation cancelled by user.")
                handle_request_disconnect(
                    f"Chat completion generation {request.state.id} cancelled by user."
                )

            generation = await gen_queue.get()

@@ -274,17 +285,35 @@ async def stream_generate_chat_completion(
            if isinstance(generation, Exception):
                raise generation

            response = _create_stream_chunk(const_id, generation, model_path.name)
            response = _create_stream_chunk(
                request.state.id, generation, model_path.name
            )
            yield response.model_dump_json()

            # Check if all tasks are completed
            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                # Send a usage chunk
                if data.stream_options and data.stream_options.include_usage:
                    usage_chunk = _create_stream_chunk(
                        request.state.id,
                        generation,
                        model_path.name,
                        is_usage_chunk=True,
                    )
                    yield usage_chunk.model_dump_json()

                logger.info(
                    f"Finished chat completion streaming request {request.state.id}"
                )

                yield "[DONE]"
                break
    except CancelledError:
        # Get out if the request gets disconnected

        abort_event.set()
        handle_request_disconnect("Chat completion generation cancelled by user.")
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect("Chat completion generation cancelled by user.")
    except Exception:
        yield get_generator_error(
            "Chat completion aborted. Please check the server console."
@@ -292,7 +321,7 @@ async def stream_generate_chat_completion(


async def generate_chat_completion(
    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
    gen_tasks: List[asyncio.Task] = []
    gen_params = data.to_gen_params()
@@ -307,16 +336,23 @@ async def generate_chat_completion(
            task_gen_params = gen_params

            gen_tasks.append(
                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
                asyncio.create_task(
                    model.container.generate(
                        prompt, request.state.id, **task_gen_params
                    )
                )
            )

        generations = await asyncio.gather(*gen_tasks)
        response = _create_response(generations, model_path.name)
        response = _create_response(request.state.id, generations, model_path.name)

        logger.info(f"Finished chat completion request {request.state.id}")

        return response
    except Exception as exc:
        error_message = handle_request_error(
            "Chat completion aborted. Maybe the model was unloaded? "
            f"Chat completion {request.state.id} aborted. "
            "Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

@@ -7,6 +7,8 @@ from copy import deepcopy
from fastapi import HTTPException, Request
from typing import List, Union

from loguru import logger

from common import model
from common.networking import (
    get_generator_error,
@@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
from endpoints.OAI.types.common import UsageStats


def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
def _create_response(
    request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
):
    """Create a completion response from the provided choices."""

    # Convert the single choice object into a list
@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)

    response = CompletionResponse(
        id=f"cmpl-{request_id}",
        choices=choices,
        model=model_name,
        usage=UsageStats(
@@ -77,13 +82,16 @@ async def _stream_collector(
    task_idx: int,
    gen_queue: asyncio.Queue,
    prompt: str,
    request_id: str,
    abort_event: asyncio.Event,
    **kwargs,
):
    """Collects a stream and places results in a common queue"""

    try:
        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
        new_generation = model.container.generate_gen(
            prompt, request_id, abort_event, **kwargs
        )
        async for generation in new_generation:
            generation["index"] = task_idx

@@ -106,6 +114,8 @@ async def stream_generate_completion(
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    try:
        logger.info(f"Received streaming completion request {request.state.id}")

        gen_params = data.to_gen_params()

        for n in range(0, data.n):
@@ -116,7 +126,12 @@ async def stream_generate_completion(

            gen_task = asyncio.create_task(
                _stream_collector(
                    n, gen_queue, data.prompt, abort_event, **task_gen_params
                    n,
                    gen_queue,
                    data.prompt,
                    request.state.id,
                    abort_event,
                    **task_gen_params,
                )
            )

@@ -126,7 +141,9 @@ async def stream_generate_completion(
        while True:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect("Completion generation cancelled by user.")
                handle_request_disconnect(
                    f"Completion generation {request.state.id} cancelled by user."
                )

            generation = await gen_queue.get()

@@ -134,31 +151,39 @@ async def stream_generate_completion(
            if isinstance(generation, Exception):
                raise generation

            response = _create_response(generation, model_path.name)
            response = _create_response(request.state.id, generation, model_path.name)
            yield response.model_dump_json()

            # Check if all tasks are completed
            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                yield "[DONE]"
                logger.info(f"Finished streaming completion request {request.state.id}")
                break
    except CancelledError:
        # Get out if the request gets disconnected

        abort_event.set()
        handle_request_disconnect("Completion generation cancelled by user.")
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Completion generation {request.state.id} cancelled by user."
            )
    except Exception:
        yield get_generator_error(
            "Completion aborted. Please check the server console."
            f"Completion {request.state.id} aborted. Please check the server console."
        )


async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
async def generate_completion(
    data: CompletionRequest, request: Request, model_path: pathlib.Path
):
    """Non-streaming generate for completions"""

    gen_tasks: List[asyncio.Task] = []
    gen_params = data.to_gen_params()

    try:
        logger.info(f"Received completion request {request.state.id}")

        for n in range(0, data.n):
            # Deepcopy gen params above the first index
            # to ensure nested structures aren't shared
@@ -169,17 +194,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)

            gen_tasks.append(
                asyncio.create_task(
                    model.container.generate(data.prompt, **task_gen_params)
                    model.container.generate(
                        data.prompt, request.state.id, **task_gen_params
                    )
                )
            )

        generations = await asyncio.gather(*gen_tasks)
        response = _create_response(generations, model_path.name)
        response = _create_response(request.state.id, generations, model_path.name)

        logger.info(f"Finished completion request {request.state.id}")

        return response
    except Exception as exc:
        error_message = handle_request_error(
            "Completion aborted. Maybe the model was unloaded? "
            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

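The `_stream_collector` pattern in the diff above fans several generation streams into one `asyncio.Queue`, tagging each item with its task index and forwarding exceptions through the same queue. A simplified, self-contained sketch of that pattern (it counts finish markers instead of polling task state, and simulates streams with plain lists; names are illustrative, not the project's API):

```python
import asyncio

async def stream_collector(task_idx, gen_queue, chunks):
    """Collects one simulated stream and places indexed results in a shared queue."""
    try:
        for text in chunks:
            await gen_queue.put({"index": task_idx, "text": text})
        # A finish marker, mirroring the "finish_reason" key used in the diff
        await gen_queue.put({"index": task_idx, "finish_reason": "stop"})
    except Exception as exc:
        # Surface errors through the queue, as the real collector does
        await gen_queue.put(exc)

async def fan_in(streams):
    gen_queue = asyncio.Queue()
    tasks = [
        asyncio.create_task(stream_collector(i, gen_queue, chunks))
        for i, chunks in enumerate(streams)
    ]

    results = []
    finished = 0
    # Drain until every stream has delivered its finish marker
    while finished < len(streams):
        generation = await gen_queue.get()
        if isinstance(generation, Exception):
            raise generation
        results.append(generation)
        if "finish_reason" in generation:
            finished += 1

    await asyncio.gather(*tasks)
    return results

results = asyncio.run(fan_in([["a", "b"], ["c"]]))
```

The shared queue is what lets `n` parallel generations stream back as one interleaved SSE response while still attributing each chunk to its choice index.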
64	endpoints/OAI/utils/embeddings.py	Normal file
@@ -0,0 +1,64 @@
"""
This file is derived from
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
and modified.
The changes introduced are: Suppression of progress bar,
typing/pydantic classes moved into this file,
embeddings function declared async.
"""

import base64
from fastapi import Request
import numpy as np
from loguru import logger

from common import model
from endpoints.OAI.types.embedding import (
    EmbeddingObject,
    EmbeddingsRequest,
    EmbeddingsResponse,
    UsageInfo,
)


def float_list_to_base64(float_array: np.ndarray) -> str:
    """
    Converts the provided list to a float32 array for OpenAI
    Ex. float_array = np.array(float_list, dtype="float32")
    """

    # Encode raw bytes into base64
    encoded_bytes = base64.b64encode(float_array.tobytes())

    # Turn raw base64 encoded bytes into ASCII
    ascii_string = encoded_bytes.decode("ascii")
    return ascii_string


async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
    model_path = model.embeddings_container.model_dir

    logger.info(f"Received embeddings request {request.state.id}")
    embedding_data = await model.embeddings_container.generate(data.input)

    # OAI expects a return of base64 if the input is base64
    embedding_object = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embedding_data.get("embeddings"))
    ]

    usage = embedding_data.get("usage")
    response = EmbeddingsResponse(
        data=embedding_object,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    logger.info(f"Finished embeddings request {request.state.id}")

    return response
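`float_list_to_base64` above base64-encodes the raw float32 bytes of a numpy array, as OpenAI clients expect for `encoding_format="base64"`. A stdlib-only sketch of the same encoding (using `struct` in place of numpy's `tobytes()`, assuming little-endian float32 layout), with a round-trip decode to show what a client does on the other side:

```python
import base64
import struct

def float_list_to_base64(floats):
    # Pack as little-endian float32; equivalent to
    # np.asarray(floats, dtype="float32").tobytes() on little-endian machines
    raw = struct.pack(f"<{len(floats)}f", *floats)
    # Encode raw bytes into base64, then into an ASCII string
    return base64.b64encode(raw).decode("ascii")

encoded = float_list_to_base64([1.0, -2.5, 0.0])

# Client side: decode base64 back into float32 values
raw = base64.b64decode(encoded)
decoded = list(struct.unpack(f"<{len(raw) // 4}f", raw))
```

The values chosen here are exactly representable in float32, so the round-trip is lossless; arbitrary doubles would be truncated to float32 precision.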
@@ -1,14 +0,0 @@
import pathlib

from endpoints.OAI.types.lora import LoraCard, LoraList


def get_lora_list(lora_path: pathlib.Path):
    """Get the list of Lora cards from the provided path."""
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id=path.name)
            lora_list.data.append(lora_card)

    return lora_list
521	endpoints/core/router.py	Normal file
@@ -0,0 +1,521 @@
|
||||
import asyncio
|
||||
import pathlib
|
||||
from sys import maxsize
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from common import config, model, sampling
|
||||
from common.auth import check_admin_key, check_api_key, get_key_permission
|
||||
from common.downloader import hf_repo_download
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import unwrap
|
||||
from endpoints.core.types.auth import AuthPermissionResponse
|
||||
from endpoints.core.types.download import DownloadRequest, DownloadResponse
|
||||
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
from endpoints.core.types.model import (
|
||||
EmbeddingModelLoadRequest,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelLoadRequest,
|
||||
ModelLoadResponse,
|
||||
)
|
||||
from endpoints.core.types.sampler_overrides import (
|
||||
SamplerOverrideListResponse,
|
||||
SamplerOverrideSwitchRequest,
|
||||
)
|
||||
from endpoints.core.types.template import TemplateList, TemplateSwitchRequest
|
||||
from endpoints.core.types.token import (
|
||||
TokenDecodeRequest,
|
||||
TokenDecodeResponse,
|
||||
TokenEncodeRequest,
|
||||
TokenEncodeResponse,
|
||||
)
|
||||
from endpoints.core.utils.lora import get_active_loras, get_lora_list
|
||||
from endpoints.core.utils.model import (
|
||||
get_current_model,
|
||||
get_current_model_list,
|
||||
get_model_list,
|
||||
stream_model_load,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Model list endpoint
|
||||
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models(request: Request) -> ModelList:
|
||||
"""
|
||||
Lists all models in the model directory.
|
||||
|
||||
Requires an admin key to see all models.
|
||||
"""
|
||||
|
||||
model_config = config.model_config()
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_model_dir = config.draft_model_config().get("draft_model_dir")
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
else:
|
||||
models = await get_current_model_list()
|
||||
|
||||
if unwrap(model_config.get("use_dummy_models"), False):
|
||||
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
|
||||
|
||||
return models
|
||||
|
||||
|
||||
# Currently loaded model endpoint
|
||||
@router.get(
|
||||
"/v1/model",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def current_model() -> ModelCard:
|
||||
"""Returns the currently loaded model."""
|
||||
|
||||
return get_current_model()
|
||||
|
||||
|
||||
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_draft_models(request: Request) -> ModelList:
|
||||
"""
|
||||
Lists all draft models in the model directory.
|
||||
|
||||
Requires an admin key to see all draft models.
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
draft_model_dir = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
else:
|
||||
models = await get_current_model_list(is_draft=True)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
# Load model endpoint
|
||||
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
||||
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
|
||||
"""Loads a model into the model container. This returns an SSE stream."""
|
||||
|
||||
# Verify request parameters
|
||||
if not data.name:
|
||||
error_message = handle_request_error(
|
||||
"A model name was not provided for load.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
draft_model_path = None
|
||||
if data.draft:
|
||||
if not data.draft.draft_model_name:
|
||||
error_message = handle_request_error(
|
||||
"Could not find the draft model name for model load.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
draft_model_path = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
|
||||
if not model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
"Could not find the model path for load. Check model name or config.yml?",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
return EventSourceResponse(
|
||||
stream_model_load(data, model_path, draft_model_path), ping=maxsize
|
||||
)
|
||||
|
||||
|
||||
# Unload model endpoint
|
||||
@router.post(
|
||||
"/v1/model/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||
)
|
||||
async def unload_model():
|
||||
"""Unloads the currently loaded model."""
|
||||
await model.unload_model(skip_wait=True)
|
||||
|
||||
|
||||
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
|
||||
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
|
||||
"""Downloads a model from HuggingFace."""
|
||||
|
||||
try:
|
||||
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
|
||||
|
||||
# For now, the downloader and request data are 1:1
|
||||
download_path = await run_with_request_disconnect(
|
||||
request,
|
||||
download_task,
|
||||
"Download request cancelled by user. Files have been cleaned up.",
|
||||
)
|
||||
|
||||
return DownloadResponse(download_path=str(download_path))
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(str(exc)).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_all_loras(request: Request) -> LoraList:
|
||||
"""
|
||||
Lists all LoRAs in the lora directory.
|
||||
|
||||
Requires an admin key to see all LoRAs.
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
else:
|
||||
loras = get_active_loras()
|
||||
|
||||
return loras
|
||||
|
||||
|
||||
# Currently loaded loras endpoint
|
||||
@router.get(
|
||||
"/v1/lora",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def active_loras() -> LoraList:
|
||||
"""Returns the currently loaded loras."""
|
||||
|
||||
return get_active_loras()
|
||||
|
||||
|
||||
# Load lora endpoint
|
||||
@router.post(
|
||||
"/v1/lora/load",
|
||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||
)
|
||||
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
|
||||
"""Loads a LoRA into the model container."""
|
||||
|
||||
if not data.loras:
|
||||
error_message = handle_request_error(
|
||||
"List of loras to load is not found.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
|
||||
if not lora_dir.exists():
|
||||
error_message = handle_request_error(
|
||||
"A parent lora directory does not exist for load. Check your config.yml?",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
load_result = await model.load_loras(
|
||||
lora_dir, **data.model_dump(), skip_wait=data.skip_queue
|
||||
)
|
||||
|
||||
return LoraLoadResponse(
|
||||
success=unwrap(load_result.get("success"), []),
|
||||
failure=unwrap(load_result.get("failure"), []),
|
||||
)
|
||||
|
||||
|
||||
# Unload lora endpoint
|
||||
@router.post(
|
||||
"/v1/lora/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
|
||||
)
|
||||
async def unload_loras():
|
||||
"""Unloads the currently loaded loras."""
|
||||
|
||||
await model.unload_loras()
|
||||
|
||||
|
||||
@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_embedding_models(request: Request) -> ModelList:
|
||||
"""
|
||||
Lists all embedding models in the model directory.
|
||||
|
||||
Requires an admin key to see all embedding models.
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
embedding_model_dir = unwrap(
|
||||
config.embeddings_config().get("embedding_model_dir"), "models"
|
||||
)
|
||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||
|
||||
models = get_model_list(embedding_model_path.resolve())
|
||||
else:
|
||||
models = await get_current_model_list(model_type="embedding")
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/model/embedding",
|
||||
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
|
||||
)
|
||||
async def get_embedding_model() -> ModelCard:
|
||||
"""Returns the currently loaded embedding model."""
|
||||
models = await get_current_model_list(model_type="embedding")
|
||||
|
||||
return models.data[0]
|
||||
|
||||
|
||||
@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
|
||||
async def load_embedding_model(
|
||||
request: Request, data: EmbeddingModelLoadRequest
|
||||
) -> ModelLoadResponse:
|
||||
# Verify request parameters
|
||||
if not data.name:
|
||||
error_message = handle_request_error(
|
||||
"A model name was not provided for load.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
embedding_model_dir = pathlib.Path(
|
||||
unwrap(config.model_config().get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_path = embedding_model_dir / data.name
|
||||
|
||||
if not embedding_model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
"Could not find the embedding model path for load. "
|
||||
+ "Check model name or config.yml?",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
try:
|
||||
load_task = asyncio.create_task(
|
||||
model.load_embedding_model(embedding_model_path, **data.model_dump())
|
||||
)
|
||||
await run_with_request_disconnect(
|
||||
request, load_task, "Embedding model load request cancelled by user."
|
||||
)
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(str(exc)).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
response = ModelLoadResponse(
|
||||
model_type="embedding_model", module=1, modules=1, status="finished"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/model/embedding/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
|
||||
)
|
||||
async def unload_embedding_model():
|
||||
"""Unloads the current embedding model."""
|
||||
|
||||
await model.unload_embedding_model()
|
||||
|
||||
|
||||
# Encode tokens endpoint
|
||||
@router.post(
|
||||
"/v1/token/encode",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
|
||||
"""Encodes a string or chat completion messages into tokens."""
|
||||
|
||||
if isinstance(data.text, str):
|
||||
text = data.text
|
||||
else:
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
unwrap(data.add_bos_token, True)
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"messages": data.text,
|
||||
"add_generation_prompt": False,
|
||||
**special_tokens_dict,
|
||||
}
|
||||
|
||||
text, _ = model.container.prompt_template.render(template_vars)
|
||||
|
||||
raw_tokens = model.container.encode_tokens(text, **data.get_params())
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Decode tokens endpoint
|
||||
@router.post(
|
||||
"/v1/token/decode",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
|
||||
"""Decodes tokens into a string."""
|
||||
|
||||
message = model.container.decode_tokens(data.tokens, **data.get_params())
|
||||
response = TokenDecodeResponse(text=unwrap(message, ""))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
|
||||
async def key_permission(request: Request) -> AuthPermissionResponse:
|
||||
"""
|
||||
Gets the access level/permission of a provided key in headers.
|
||||
|
||||
Priority:
|
||||
- X-admin-key
|
||||
- X-api-key
|
||||
- Authorization
|
||||
"""
|
||||
|
||||
try:
|
||||
permission = get_key_permission(request)
|
||||
return AuthPermissionResponse(permission=permission)
|
||||
except ValueError as exc:
|
||||
error_message = handle_request_error(str(exc)).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
|
||||
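The documented header priority can be sketched as a standalone function over a plain headers dict. `resolve_permission` below is a hypothetical stand-in for the real `get_key_permission` (which lives in the project's auth module and may differ); the point is the check order:

```python
def resolve_permission(headers: dict) -> str:
    """Hypothetical sketch: X-admin-key outranks X-api-key, which outranks Authorization."""
    if "x-admin-key" in headers:
        return "admin"
    if "x-api-key" in headers or "authorization" in headers:
        return "api"
    raise ValueError("No authorization key found in headers")
```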
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
||||
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_templates(request: Request) -> TemplateList:
|
||||
"""
|
||||
Get a list of all templates.
|
||||
|
||||
Requires an admin key to see all templates.
|
||||
"""
|
||||
|
||||
template_strings = []
|
||||
if get_key_permission(request) == "admin":
|
||||
templates = get_all_templates()
|
||||
template_strings = [template.stem for template in templates]
|
||||
else:
|
||||
if model.container and model.container.prompt_template:
|
||||
template_strings.append(model.container.prompt_template.name)
|
||||
|
||||
return TemplateList(data=template_strings)
|
||||
|
||||
|
||||
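`template.stem` comes from `pathlib`: it drops both the parent directories and the file extension, which is how template files on disk map to the bare names this endpoint returns. The file names below are illustrative:

```python
import pathlib

# Hypothetical template files on disk
templates = [
    pathlib.Path("templates/alpaca.jinja"),
    pathlib.Path("templates/chatml.jinja"),
]
template_strings = [template.stem for template in templates]
# template_strings is now ["alpaca", "chatml"]
```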
@router.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Switch the currently loaded template."""

    if not data.name:
        error_message = handle_request_error(
            "New template name not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        model.container.prompt_template = PromptTemplate.from_file(data.name)
    except FileNotFoundError as e:
        error_message = handle_request_error(
            f"The template name {data.name} doesn't exist. Check the spelling?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message) from e

@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Unloads the currently selected template"""

    model.container.prompt_template = None

# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
    """
    List all currently applied sampler overrides.

    Requires an admin key to see all override presets.
    """

    if get_key_permission(request) == "admin":
        presets = sampling.get_all_presets()
    else:
        presets = []

    return SamplerOverrideListResponse(
        presets=presets, **sampling.overrides_container.model_dump()
    )

@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """Switch the currently loaded override preset"""

    if data.preset:
        try:
            sampling.overrides_from_file(data.preset)
        except FileNotFoundError as e:
            error_message = handle_request_error(
                f"Sampler override preset with name {data.preset} does not exist. "
                + "Check the spelling?",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message) from e
    elif data.overrides:
        sampling.overrides_from_dict(data.overrides)
    else:
        error_message = handle_request_error(
            "A sampler override preset or dictionary wasn't provided.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Unloads the currently selected override preset"""

    sampling.overrides_from_dict({})
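The switch endpoint's branches reduce to a precedence rule: a named preset wins over an inline overrides dictionary, and providing neither is a client error. A pure-function sketch of that dispatch (the function name here is illustrative, not part of the API):

```python
def choose_override_source(preset=None, overrides=None) -> str:
    """Illustrative sketch of the endpoint's precedence: preset file > inline dict > error."""
    if preset:
        return "preset"
    elif overrides:
        return "dict"
    else:
        raise ValueError("A sampler override preset or dictionary wasn't provided.")
```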
@@ -17,6 +17,7 @@ class DownloadRequest(BaseModel):
    include: List[str] = Field(default_factory=_generate_include_list)
    exclude: List[str] = Field(default_factory=list)
    chunk_limit: Optional[int] = None
    timeout: Optional[int] = None


class DownloadResponse(BaseModel):
@@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
    # Config arguments
    draft_rope_scale: Optional[float] = Field(
        default_factory=lambda: get_config_default(
            "draft_rope_scale", 1.0, is_draft=True
            "draft_rope_scale", 1.0, model_type="draft"
        )
    )
    draft_rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
        default_factory=lambda: get_config_default(
            "draft_rope_alpha", None, is_draft=True
            "draft_rope_alpha", None, model_type="draft"
        ),
        examples=[1.0],
    )
    draft_cache_mode: Optional[str] = Field(
        default_factory=lambda: get_config_default(
            "draft_cache_mode", "FP16", is_draft=True
            "draft_cache_mode", "FP16", model_type="draft"
        )
    )

@@ -137,6 +137,15 @@ class ModelLoadRequest(BaseModel):
    skip_queue: Optional[bool] = False


class EmbeddingModelLoadRequest(BaseModel):
    name: str
    embeddings_device: Optional[str] = Field(
        default_factory=lambda: get_config_default(
            "embeddings_device", model_type="embedding"
        )
    )


class ModelLoadResponse(BaseModel):
    """Represents a model load response."""

30
endpoints/core/utils/lora.py
Normal file
@@ -0,0 +1,30 @@
import pathlib

from common import model
from endpoints.core.types.lora import LoraCard, LoraList


def get_lora_list(lora_path: pathlib.Path):
    """Get the list of Lora cards from the provided path."""
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id=path.name)
            lora_list.data.append(lora_card)

    return lora_list


def get_active_loras():
    if model.container:
        active_loras = [
            LoraCard(
                id=pathlib.Path(lora.lora_path).parent.name,
                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
            )
            for lora in model.container.get_loras()
        ]
    else:
        active_loras = []

    return LoraList(data=active_loras)
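The `scaling` reported for each active LoRA folds the adapter's rank and alpha into the user-supplied multiplier, i.e. effective scale = lora_scaling * lora_r / lora_alpha. Factored out for illustration:

```python
def effective_lora_scale(lora_scaling: float, lora_r: int, lora_alpha: int) -> float:
    """Same arithmetic as the LoraCard scaling field above."""
    return lora_scaling * lora_r / lora_alpha

# A rank-16, alpha-32 adapter at user scale 1.0 contributes at half strength
half = effective_lora_scale(1.0, 16, 32)  # -> 0.5
```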
@@ -2,11 +2,12 @@ import pathlib
from asyncio import CancelledError
from typing import Optional

from common import model
from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect

from endpoints.OAI.types.model import (
from common.utils import unwrap
from endpoints.core.types.model import (
    ModelCard,
    ModelCardParameters,
    ModelList,
    ModelLoadRequest,
    ModelLoadResponse,
@@ -31,6 +32,61 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
    return model_card_list


async def get_current_model_list(model_type: str = "model"):
    """
    Gets the current model in list format and with path only.

    Unified for fetching both models and embedding models.
    """

    current_models = []
    model_path = None

    # Make sure the model container exists
    if model_type == "model" or model_type == "draft":
        if model.container:
            model_path = model.container.get_model_path(model_type == "draft")
    elif model_type == "embedding":
        if model.embeddings_container:
            model_path = model.embeddings_container.model_dir

    if model_path:
        current_models.append(ModelCard(id=model_path.name))

    return ModelList(data=current_models)


def get_current_model():
    """Gets the current model with all parameters."""

    model_params = model.container.get_model_parameters()
    draft_model_params = model_params.pop("draft", {})

    if draft_model_params:
        model_params["draft"] = ModelCard(
            id=unwrap(draft_model_params.get("name"), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )
    else:
        draft_model_params = None

    model_card = ModelCard(
        id=unwrap(model_params.pop("name", None), "unknown"),
        parameters=ModelCardParameters.model_validate(model_params),
        logging=gen_logging.PREFERENCES,
    )

    if draft_model_params:
        draft_card = ModelCard(
            id=unwrap(draft_model_params.pop("name", None), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )

        model_card.parameters.draft = draft_card

    return model_card


async def stream_model_load(
    data: ModelLoadRequest,
    model_path: pathlib.Path,
@@ -1,40 +1,74 @@
import asyncio
from typing import Optional
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from common import config
from common.logger import UVICORN_LOG_CONFIG
from endpoints.OAI.router import router as OAIRouter

app = FastAPI(
    title="TabbyAPI",
    summary="An OAI compatible exllamav2 API that's both lightweight and fast",
    description=(
        "This docs page is not meant to send requests! Please use a service "
        "like Postman or a frontend UI."
    ),
)

# Allow CORS requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
from common.networking import get_global_depends
from common.utils import unwrap
from endpoints.Kobold import router as KoboldRouter
from endpoints.OAI import router as OAIRouter
from endpoints.core.router import router as CoreRouter


def setup_app():
def setup_app(host: Optional[str] = None, port: Optional[int] = None):
    """Includes the correct routers for startup"""

    app.include_router(OAIRouter)
    app = FastAPI(
        title="TabbyAPI",
        summary="An OAI compatible exllamav2 API that's both lightweight and fast",
        description=(
            "This docs page is not meant to send requests! Please use a service "
            "like Postman or a frontend UI."
        ),
        dependencies=get_global_depends(),
    )

    # Allow CORS requests
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    api_servers = unwrap(config.network_config().get("api_servers"), [])

    # Map for API id to server router
    router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}

    # Include the OAI api by default
    if api_servers:
        for server in api_servers:
            selected_server = router_mapping.get(server.lower())

            if selected_server:
                app.include_router(selected_server.setup())

                logger.info(f"Starting {selected_server.api_name} API")
                for path, url in selected_server.urls.items():
                    formatted_url = url.format(host=host, port=port)
                    logger.info(f"{path}: {formatted_url}")
    else:
        app.include_router(OAIRouter.setup())
        for path, url in OAIRouter.urls.items():
            formatted_url = url.format(host=host, port=port)
            logger.info(f"{path}: {formatted_url}")

    # Include core API request paths
    app.include_router(CoreRouter)

    return app


def export_openapi():
    """Function to return the OpenAPI JSON from the API server"""

    setup_app()
    app = setup_app()
    return app.openapi()


@@ -43,17 +77,21 @@ async def start_api(host: str, port: int):

    # TODO: Move OAI API to a separate folder
    logger.info(f"Developer documentation: http://{host}:{port}/redoc")
    logger.info(f"Completions: http://{host}:{port}/v1/completions")
    logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
    # logger.info(f"Completions: http://{host}:{port}/v1/completions")
    # logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")

    # Setup app
    setup_app()
    app = setup_app(host, port)

    # Get the current event loop
    loop = asyncio.get_running_loop()

    config = uvicorn.Config(
        app,
        host=host,
        port=port,
        log_config=UVICORN_LOG_CONFIG,
        loop=loop,
    )
    server = uvicorn.Server(config)

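The `api_servers` loop is a lookup with a fallback: unknown server ids are skipped, and an empty list falls back to the OAI router. A pure-function sketch of that selection, with strings standing in for the actual router objects:

```python
def select_routers(api_servers, router_mapping, default="oai"):
    """Mirrors setup_app's selection: case-insensitive lookup, OAI fallback."""
    if not api_servers:
        return [router_mapping[default]]
    return [
        router_mapping[server.lower()]
        for server in api_servers
        if server.lower() in router_mapping
    ]

mapping = {"oai": "OAIRouter", "kobold": "KoboldRouter"}
```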
143
main.py
@@ -1,10 +1,10 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""

import asyncio
import aiofiles
import json
import os
import pathlib
import platform
import signal
from loguru import logger
from typing import Optional
@@ -23,51 +23,8 @@ if not do_export_openapi:
    from backends.exllamav2.utils import check_exllama_version


async def entrypoint(args: Optional[dict] = None):
    """Entry function for program startup"""

    setup_logger()

    # Set up signal aborting
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    if os.getenv("EXPORT_OPENAPI", "").lower() in ("true", "1"):
        openapi_json = export_openapi()

        async with aiofiles.open("openapi.json", "w") as f:
            await f.write(json.dumps(openapi_json))
        logger.info("Successfully wrote OpenAPI spec to openapi.json")

        return

    # Load from YAML config
    config.from_file(pathlib.Path("config.yml"))

    # Parse and override config from args
    if args is None:
        parser = init_argparser()
        args = convert_args_to_dict(parser.parse_args(), parser)

    config.from_args(args)

    developer_config = config.developer_config()

    # Check exllamav2 version and give a descriptive error if it's too old
    # Skip if launching unsafely

    if unwrap(developer_config.get("unsafe_launch"), False):
        logger.warning(
            "UNSAFE: Skipping ExllamaV2 version check.\n"
            "If you aren't a developer, please keep this off!"
        )
    else:
        check_exllama_version()

    # Enable CUDA malloc backend
    if unwrap(developer_config.get("cuda_malloc_backend"), False):
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
        logger.warning("Enabled the experimental CUDA malloc backend.")
async def entrypoint_async():
    """Async entry function for program startup"""

    network_config = config.network_config()

@@ -97,7 +54,7 @@ async def entrypoint(args: Optional[dict] = None):
    load_auth_keys(unwrap(network_config.get("disable_auth"), False))

    # Override the generation log options if given
    log_config = config.gen_logging_config()
    log_config = config.logging_config()
    if log_config:
        gen_logging.update_from_dict(log_config)

@@ -128,8 +85,98 @@ async def entrypoint(args: Optional[dict] = None):
        lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
        await model.container.load_loras(lora_dir.resolve(), **lora_config)

    # If an initial embedding model name is specified, create a separate container
    # and load the model
    embedding_config = config.embeddings_config()
    embedding_model_name = embedding_config.get("embedding_model_name")
    if embedding_model_name:
        embedding_model_path = pathlib.Path(
            unwrap(embedding_config.get("embedding_model_dir"), "models")
        )
        embedding_model_path = embedding_model_path / embedding_model_name

        try:
            await model.load_embedding_model(embedding_model_path, **embedding_config)
        except ImportError as ex:
            logger.error(ex.msg)

    await start_api(host, port)


def entrypoint(arguments: Optional[dict] = None):
    setup_logger()

    # Set up signal aborting
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Load from YAML config
    config.from_file(pathlib.Path("config.yml"))

    # Parse and override config from args
    if arguments is None:
        parser = init_argparser()
        arguments = convert_args_to_dict(parser.parse_args(), parser)

    config.from_args(arguments)

    if do_export_openapi:
        openapi_json = export_openapi()

        with open("openapi.json", "w") as f:
            f.write(json.dumps(openapi_json))
        logger.info("Successfully wrote OpenAPI spec to openapi.json")

        return

    developer_config = config.developer_config()

    # Check exllamav2 version and give a descriptive error if it's too old
    # Skip if launching unsafely

    if unwrap(developer_config.get("unsafe_launch"), False):
        logger.warning(
            "UNSAFE: Skipping ExllamaV2 version check.\n"
            "If you aren't a developer, please keep this off!"
        )
    else:
        check_exllama_version()

    # Enable CUDA malloc backend
    if unwrap(developer_config.get("cuda_malloc_backend"), False):
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
        logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.")

    # Use Uvloop/Winloop
    if unwrap(developer_config.get("uvloop"), False):
        if platform.system() == "Windows":
            from winloop import install
        else:
            from uvloop import install

        # Set loop event policy
        install()

        logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.")

    # Set the process priority
    if unwrap(developer_config.get("realtime_process_priority"), False):
        import psutil

        current_process = psutil.Process(os.getpid())
        if platform.system() == "Windows":
            current_process.nice(psutil.REALTIME_PRIORITY_CLASS)
        else:
            current_process.nice(psutil.IOPRIO_CLASS_RT)

        logger.warning(
            "EXPERIMENTAL: Process priority set to Realtime. \n"
            "If you're not running on administrator/sudo, the priority is set to high."
        )

    # Enter into the async event loop
    asyncio.run(entrypoint_async())


if __name__ == "__main__":
    asyncio.run(entrypoint())
    entrypoint()

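Startup applies configuration in two layers: `config.yml` is read first, then parsed CLI arguments override it. Assuming each layer flattens to a dict (the real `config` module keeps more structure than this), the precedence is a shallow merge where non-None CLI values win:

```python
def merge_config(yaml_config: dict, arg_overrides: dict) -> dict:
    """Illustrative sketch: later (CLI) values override earlier (YAML) ones."""
    merged = dict(yaml_config)
    merged.update({k: v for k, v in arg_overrides.items() if v is not None})
    return merged
```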
@@ -16,7 +16,7 @@ version = "0.0.1"
|
||||
description = "An OAI compatible exllamav2 API that's both lightweight and fast"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"fastapi >= 0.110.0",
|
||||
"fastapi-slim >= 0.110.0",
|
||||
"pydantic >= 2.0.0",
|
||||
"PyYAML",
|
||||
"rich",
|
||||
@@ -28,13 +28,21 @@ dependencies = [
|
||||
"tokenizers",
|
||||
"lm-format-enforcer >= 0.9.6",
|
||||
"aiofiles",
|
||||
"aiohttp",
|
||||
"huggingface_hub",
|
||||
"psutil",
|
||||
"httptools>=0.5.0",
|
||||
|
||||
# Improved asyncio loops
|
||||
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"winloop ; platform_system == 'Windows'",
|
||||
|
||||
# TEMP: Remove once 2.x is fixed in upstream
|
||||
"numpy < 2.0.0",
|
||||
|
||||
# TODO: Maybe move these to a downloader feature?
|
||||
"aiohttp",
|
||||
"huggingface_hub",
|
||||
# For python 3.12
|
||||
"fastparquet @ https://github.com/theroyallab/fastparquet/releases/download/v2024.5.0/fastparquet-0.1.dev837-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"setuptools ; python_version == '3.12'"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -43,7 +51,9 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
extras = [
|
||||
# Heavy dependencies that aren't for everyday use
|
||||
"outlines"
|
||||
"outlines",
|
||||
"infinity-emb",
|
||||
"sentence-transformers",
|
||||
]
|
||||
dev = [
|
||||
"ruff == 0.3.2"
|
||||
@@ -58,22 +68,22 @@ cu121 = [
|
||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
|
||||
# Exl2
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
|
||||
# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
|
||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||
|
||||
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
]
|
||||
cu118 = [
|
||||
# Torch
|
||||
@@ -85,17 +95,17 @@ cu118 = [
|
||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
|
||||
# Exl2
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
|
||||
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||
]
|
||||
amd = [
|
||||
# Torch triton for ROCm
|
||||
@@ -109,9 +119,9 @@ amd = [
|
||||
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
||||
|
||||
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]
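The wheel pins above rely on PEP 508 environment markers (`python_version`, `platform_system`, `platform_machine`) so that exactly one wheel matches any given interpreter. As a standalone illustration of that selection (the `pick_wheel` helper and `CP_TAGS` mapping are hypothetical, not part of the project):

```python
import platform
import sys

# Hypothetical mapping: interpreter version -> cp tag used in the wheel filename
CP_TAGS = {(3, 10): "cp310", (3, 11): "cp311", (3, 12): "cp312"}


def pick_wheel(
    base_url: str,
    version: str,
    py=(sys.version_info.major, sys.version_info.minor),
    system=platform.system(),
    machine=platform.machine(),
):
    """Return the single wheel URL the markers would select, or None."""
    # Mirrors: platform_system == 'Linux' and platform_machine == 'x86_64'
    if system != "Linux" or machine != "x86_64":
        return None
    # Mirrors: python_version == '3.x' -- narrows the choice to one cp tag
    tag = CP_TAGS.get(tuple(py))
    if tag is None:
        return None
    return (
        f"{base_url}/v{version}/"
        f"flash_attn-{version}+cu118torch2.3cxx11abiFALSE-{tag}-{tag}-linux_x86_64.whl"
    )
```

Because the markers are mutually exclusive, pip installs at most one of the listed wheels per environment.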

# MARK: Ruff options
@@ -19,3 +19,5 @@ if exist "%CONDA_PREFIX%" (
 :: Call the python script with batch args
 call python start.py %*
+
+pause
184
start.py
@@ -1,17 +1,21 @@
 """Utility to automatically upgrade and start the API"""

-import asyncio
 import argparse
+import json
 import os
 import pathlib
 import platform
 import subprocess
 import sys
 from shutil import copyfile
+import traceback

 from common.args import convert_args_to_dict, init_argparser


+start_options = {}
+
+
 def get_user_choice(question: str, options_dict: dict):
     """
     Gets user input in a commandline script.
@@ -40,36 +44,24 @@ def get_install_features(lib_name: str = None):
     """Fetches the appropriate requirements file depending on the GPU"""
     install_features = None
     possible_features = ["cu121", "cu118", "amd"]
-    saved_lib_path = pathlib.Path("gpu_lib.txt")

-    if lib_name:
-        print("Overriding GPU lib name from args.")
-    else:
-        # Try getting the GPU lib from file
-        if saved_lib_path.exists():
-            with open(saved_lib_path.resolve(), "r") as f:
-                lib_name = f.readline().strip()
-        else:
-            # Ask the user for the GPU lib
-            gpu_lib_choices = {
-                "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
-                "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
-                "C": {"pretty": "AMD", "internal": "amd"},
-            }
-            user_input = get_user_choice(
-                "Select your GPU. If you don't know, select Cuda 12.x (A)",
-                gpu_lib_choices,
-            )
+    if not lib_name:
+        # Ask the user for the GPU lib
+        gpu_lib_choices = {
+            "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
+            "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
+            "C": {"pretty": "AMD", "internal": "amd"},
+        }
+        user_input = get_user_choice(
+            "Select your GPU. If you don't know, select Cuda 12.x (A)",
+            gpu_lib_choices,
+        )

-            lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
+        lib_name = gpu_lib_choices.get(user_input, {}).get("internal")

-            # Write to a file for subsequent runs
-            with open(saved_lib_path.resolve(), "w") as f:
-                f.write(lib_name)
-            print(
-                "Saving your choice to gpu_lib.txt. "
-                "Delete this file and restart if you want to change your selection."
-            )
+        # Write to start options
+        start_options["gpu_lib"] = lib_name
+        print("Saving your choice to start options.")

     # Assume default if the file is invalid
     if lib_name and lib_name in possible_features:
@@ -79,7 +71,7 @@ def get_install_features(lib_name: str = None):
         print(
             f"WARN: GPU library {lib_name} not found. "
             "Skipping GPU-specific dependencies.\n"
-            "WARN: Please delete gpu_lib.txt and restart "
+            "WARN: Please remove the `gpu_lib` key from start_options.json and restart "
            "if you want to change your selection."
         )
         return
@@ -105,10 +97,22 @@ def add_start_args(parser: argparse.ArgumentParser):
     """Add start script args to the provided parser"""
     start_group = parser.add_argument_group("start")
     start_group.add_argument(
-        "-iu",
-        "--ignore-upgrade",
+        "-ur",
+        "--update-repository",
         action="store_true",
-        help="Ignore requirements upgrade",
+        help="Update local git repository to latest",
     )
+    start_group.add_argument(
+        "-ud",
+        "--update-deps",
+        action="store_true",
+        help="Update all pip dependencies",
+    )
+    start_group.add_argument(
+        "-fr",
+        "--force-reinstall",
+        action="store_true",
+        help="Forces a reinstall of dependencies. Only works with --update-deps",
+    )
     start_group.add_argument(
         "-nw",
@@ -123,6 +127,26 @@ def add_start_args(parser: argparse.ArgumentParser):
     )


+def migrate_gpu_lib():
+    gpu_lib_path = pathlib.Path("gpu_lib.txt")
+
+    if not gpu_lib_path.exists():
+        return
+
+    print("Migrating gpu_lib.txt to the new start_options.json")
+    with open("gpu_lib.txt", "r") as gpu_lib_file:
+        start_options["gpu_lib"] = gpu_lib_file.readline().strip()
+        start_options["first_run_done"] = True
+
+    # Remove the old file
+    gpu_lib_path.unlink()
+
+    print(
+        "Successfully migrated gpu lib options to start_options. "
+        "The old file has been deleted."
+    )


 if __name__ == "__main__":
     subprocess.run(["pip", "-V"])
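Outside the diff, the migration that `migrate_gpu_lib` performs is a small read-then-delete pattern: fold the single line of the legacy file into the options dict, then remove the file. A minimal standalone sketch, using a throwaway directory and a hypothetical `migrate` helper rather than the project's globals:

```python
import pathlib
import tempfile


def migrate(old_file: pathlib.Path, options: dict) -> dict:
    """Fold a legacy one-line gpu_lib.txt into an options dict, then delete it."""
    if old_file.exists():
        # The legacy format is a single line naming the GPU library
        options["gpu_lib"] = old_file.read_text().splitlines()[0].strip()
        options["first_run_done"] = True
        # Remove the old file so the migration only happens once
        old_file.unlink()
    return options


# Demo against a temporary file standing in for gpu_lib.txt
with tempfile.TemporaryDirectory() as tmp:
    legacy = pathlib.Path(tmp) / "gpu_lib.txt"
    legacy.write_text("cu121\n")
    opts = migrate(legacy, {})
```

Because the file is deleted at the end, running the migration a second time is a no-op.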
@@ -130,6 +154,35 @@ if __name__ == "__main__":
     parser = init_argparser()
     add_start_args(parser)
     args = parser.parse_args()
+    script_ext = "bat" if platform.system() == "Windows" else "sh"
+
+    start_options_path = pathlib.Path("start_options.json")
+    if start_options_path.exists():
+        with open(start_options_path) as start_options_file:
+            start_options = json.load(start_options_file)
+        print("Loaded your saved preferences from `start_options.json`")
+
+    if start_options.get("first_run_done"):
+        first_run = False
+    else:
+        print(
+            "It looks like you're running TabbyAPI for the first time. "
+            "Getting things ready..."
+        )
+
+        # Migrate from old setting storage
+        migrate_gpu_lib()
+
+        # Set variables that rely on start options
+        first_run = not start_options.get("first_run_done")
+
+    if args.gpu_lib:
+        print("Overriding GPU lib name from args.")
+        gpu_lib = args.gpu_lib
+    elif "gpu_lib" in start_options:
+        gpu_lib = start_options.get("gpu_lib")
+    else:
+        gpu_lib = None

     # Create a config if it doesn't exist
     # This is not necessary to run TabbyAPI, but is new user proof
@@ -145,18 +198,65 @@ if __name__ == "__main__":
             f"Created one at {str(config_path.resolve())}"
         )

-    if args.ignore_upgrade:
-        print("Ignoring pip dependency upgrade due to user request.")
-    else:
-        install_features = None if args.nowheel else get_install_features(args.gpu_lib)
-        features = f"[{install_features}]" if install_features else ""
+    if args.update_repository:
+        print("Pulling latest changes from Github.")
+        pull_command = "git pull"
+        subprocess.run(pull_command.split(" "))

+    if first_run or args.update_deps:
+        install_command = ["pip", "install", "-U"]
+
+        # Force a reinstall of the updated dependency if needed
+        if args.force_reinstall:
+            install_command.append("--force-reinstall")
+
+        install_features = None if args.nowheel else get_install_features(gpu_lib)
+        features = f".[{install_features}]" if install_features else "."
+        install_command.append(features)
+
         # pip install .[features]
-        install_command = f"pip install -U .{features}"
-        print(f"Running install command: {install_command}")
-        subprocess.run(install_command.split(" "))
+        print(f"Running install command: {' '.join(install_command)}")
+        subprocess.run(install_command)
         print()

+        if args.update_deps:
+            print(
+                f"Dependencies updated. Please run TabbyAPI with `start.{script_ext}`. "
+                "Exiting."
+            )
+            sys.exit(0)
+        else:
+            print(
+                f"Dependencies installed. Update them with `update_deps.{script_ext}` "
+                "inside the `update_scripts` folder."
+            )
+
+    if first_run:
+        start_options["first_run_done"] = True
+
+        # Save start options
+        with open("start_options.json", "w") as start_file:
+            start_file.write(json.dumps(start_options))
+
+        print(
+            "Successfully wrote your start script options to `start_options.json`. \n"
+            "If something goes wrong, editing or deleting the file "
+            "will reinstall TabbyAPI as a first-time user."
+        )
+
     # Import entrypoint after installing all requirements
-    from main import entrypoint
+    try:
+        from main import entrypoint

-    asyncio.run(entrypoint(convert_args_to_dict(args, parser)))
+        converted_args = convert_args_to_dict(args, parser)
+
+        print("Starting TabbyAPI...")
+        entrypoint(converted_args)
+    except (ModuleNotFoundError, ImportError):
+        traceback.print_exc()
+        print(
+            "\n"
+            "This error was raised because a package was not found.\n"
+            "Update your dependencies by running update_scripts/"
+            f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'}\n\n"
+        )
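The `start_options.json` handling this commit introduces boils down to a load-if-present / save-on-first-run round trip. A minimal sketch of that pattern (file name and keys mirror the diff, but the `load_options`/`save_options` helpers are illustrative, not the project's code):

```python
import json
import pathlib
import tempfile


def load_options(path: pathlib.Path) -> dict:
    """Return saved start options, or an empty dict on a first run."""
    if path.exists():
        return json.loads(path.read_text())
    return {}


def save_options(path: pathlib.Path, options: dict) -> None:
    """Persist options; deleting the file later re-triggers first-run setup."""
    options["first_run_done"] = True
    path.write_text(json.dumps(options))


# Simulate two launches against a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    opts_path = pathlib.Path(tmp) / "start_options.json"
    first = load_options(opts_path)            # empty dict -> first run
    save_options(opts_path, {"gpu_lib": "cu121"})
    second = load_options(opts_path)           # preferences survive the restart
```

This is also why the commit message in the diff says editing or deleting the file "will reinstall TabbyAPI as a first-time user": an absent or invalid file simply yields the first-run path again.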
24
update_scripts/update_deps.bat
Normal file
@@ -0,0 +1,24 @@
+@echo off
+
+:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades
+:: This is intended for users who want to start the API and have everything upgraded and installed
+
+:: cd to the parent directory
+cd "%~dp0.."
+
+:: Don't create a venv if a conda environment is active
+if exist "%CONDA_PREFIX%" (
+    echo It looks like you're in a conda environment. Skipping venv check.
+) else (
+    if not exist "venv\" (
+        echo Venv doesn't exist! Please run start.bat instead.
+        exit 0
+    )
+
+    call .\venv\Scripts\activate.bat
+)
+
+:: Call the python script with batch args
+call python start.py --update-deps %*
+
+pause
19
update_scripts/update_deps.sh
Executable file
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+cd "$(dirname "$0")/.." || exit
+
+if [ -n "$CONDA_PREFIX" ]; then
+    echo "It looks like you're in a conda environment. Skipping venv check."
+else
+    if [ ! -d "venv" ]; then
+        echo "Venv doesn't exist! Please run start.sh instead."
+        exit 0
+    fi
+
+    echo "Activating venv"
+
+    # shellcheck source=/dev/null
+    source venv/bin/activate
+fi
+
+python3 start.py --update-deps "$@"
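Both wrapper scripts make the same two-step decision: a set `CONDA_PREFIX` means conda owns the environment, otherwise a `./venv` directory must exist. Sketched in Python for clarity (the `decide_env` helper is hypothetical, not part of the repo):

```python
import pathlib


def decide_env(environ: dict, root: pathlib.Path) -> str:
    """Classify the environment the update scripts would run in."""
    # A non-empty CONDA_PREFIX means conda manages the environment
    if environ.get("CONDA_PREFIX"):
        return "conda"
    # Otherwise a ./venv directory is required, as in the scripts above
    if (root / "venv").is_dir():
        return "venv"
    # Matches the scripts' "Venv doesn't exist!" bail-out
    return "missing"
```

Note the scripts exit with status `0` even in the missing-venv case, so callers cannot distinguish that bail-out from success by exit code alone.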
24
update_scripts/update_deps_and_pull.bat
Normal file
@@ -0,0 +1,24 @@
+@echo off
+
+:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades
+:: This is intended for users who want to start the API and have everything upgraded and installed
+
+:: cd to the parent directory
+cd "%~dp0.."
+
+:: Don't create a venv if a conda environment is active
+if exist "%CONDA_PREFIX%" (
+    echo It looks like you're in a conda environment. Skipping venv check.
+) else (
+    if not exist "venv\" (
+        echo Venv doesn't exist! Please run start.bat instead.
+        exit 0
+    )
+
+    call .\venv\Scripts\activate.bat
+)
+
+:: Call the python script with batch args
+call python start.py --update-deps --update-repository %*
+
+pause
19
update_scripts/update_deps_and_pull.sh
Executable file
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+cd "$(dirname "$0")/.." || exit
+
+if [ -n "$CONDA_PREFIX" ]; then
+    echo "It looks like you're in a conda environment. Skipping venv check."
+else
+    if [ ! -d "venv" ]; then
+        echo "Venv doesn't exist! Please run start.sh instead."
+        exit 0
+    fi
+
+    echo "Activating venv"
+
+    # shellcheck source=/dev/null
+    source venv/bin/activate
+fi
+
+python3 start.py --update-deps --update-repository "$@"