diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 2bbf60e..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,35 +0,0 @@ ---- -name: Bug report -about: Report code related issues -title: "[BUG]" -labels: bug -assignees: '' - ---- - -**Disclaimer:** Github Issues are **only** for code related bugs. If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj) - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Logs** -If applicable, add logs and tracebacks to help explain your problem. - -**System info** (Bugs without this information will go lower on our priority list!) - - OS: [ex. Windows] - - Python version: [ex. 3.11] - - CUDA/ROCm version: [ex. 12.x] - - Python version: [ex. 3.11] - -**Additional context** -Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000..c520e24 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,97 @@ +name: Bug report +description: Report code related issues +title: "[BUG]" +labels: bug +body: + +- type: markdown + attributes: + value: | + ### Disclaimer: + Github Issues are **only** for code related bugs. + If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj) + +- type: dropdown + attributes: + label: OS + options: + - Windows + - Linux + validations: + required: true + +- type: dropdown + attributes: + label: GPU Library + description: Ex. 
CUDA, ROCm + options: + - CUDA 12.x + - CUDA 11.8 + - AMD ROCm + validations: + required: true + +- type: dropdown + attributes: + label: Python version + options: + - '3.12' + - '3.11' + - '3.10' + validations: + required: true + +- type: textarea + attributes: + label: Describe the bug + description: A clear and concise description of what the bug is. + validations: + required: true + +- type: textarea + attributes: + label: Reproduction steps + description: Walk us through how the bug occurred and how to make it happen. + validations: + required: true + +- type: textarea + attributes: + label: Expected behavior + description: What was expected to happen? + validations: + required: true + +- type: textarea + attributes: + label: Logs + description: If applicable, add logs and tracebacks to help explain your problem. + validations: + required: false + +- type: textarea + attributes: + label: Additional context + description: Add any other context about the problem here. + validations: + required: false + +- type: checkboxes + attributes: + label: Acknowledgements + description: Before submitting this issue, please make sure you have completed the following checklist. + options: + - label: I have looked for similar issues before submitting this one. + required: true + - label: I have read the disclaimer, and this issue is related to a code bug. If I have a question, I will use the Discord server. + required: true + - label: I understand that the developers have lives and my issue will be answered when possible. + required: true + - label: I understand the developers of this program are human, and I will ask my questions politely. + required: true + +- type: markdown + attributes: + value: | + ## Thanks! + Well-formatted issues improve TabbyAPI and make the development process smoother. 
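Not part of the diff: the system details the new bug report form asks for (OS and Python version) can be gathered with a short stdlib snippet. A minimal sketch, handy when filling out a report (the function name is illustrative, not from the codebase):

```python
import platform
import sys


def collect_system_info() -> dict:
    """Gather the basic fields the bug report form asks for."""
    return {
        "os": platform.system(),  # ex. "Windows" or "Linux"
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
    }


if __name__ == "__main__":
    print(collect_system_info())
```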
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index e771a8d..0000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -name: Feature request -about: Suggest a new idea -title: "[REQUEST]" -labels: '' -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Why should this feature be added?** -An explanation of why the feature should be added. Please be as specific as possible to help us understand the reasoning. - -**Examples** -Examples of the feature in action and its significance compared to not having the feature. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml new file mode 100644 index 0000000..9cf3d8b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yaml @@ -0,0 +1,69 @@ +name: Feature request +description: Suggest a new idea +title: "[REQUEST]" +body: + +- type: textarea + attributes: + label: Problem + description: Is the feature request related to a problem? If so, please describe. + placeholder: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + validations: + required: false + +- type: textarea + attributes: + label: Solution + description: Describe the solution you'd like. + placeholder: A clear and concise description of what you want to happen. 
+ validations: + required: true + +- type: textarea + attributes: + label: Alternatives + description: What alternative options did you consider? + validations: + required: false + +- type: textarea + attributes: + label: Explanation + description: Why should this feature be added? + validations: + required: true + +- type: textarea + attributes: + label: Examples + description: | + Examples of the feature in action and its significance. + + Not required, but will make your request easier to understand. + validations: + required: false + +- type: textarea + attributes: + label: Additional context + description: Anything else to add? + validations: + required: false + +- type: checkboxes + attributes: + label: Acknowledgements + description: Before submitting this issue, please make sure you have completed the following checklist. + options: + - label: I have looked for similar requests before submitting this one. + required: true + - label: I understand that the developers have lives and my issue will be answered when possible. + required: true + - label: I understand the developers of this program are human, and I will make my requests politely. + required: true + +- type: markdown + attributes: + value: | + ## Thanks! + Well-formatted issues improve TabbyAPI and make the development process smoother. 
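Both forms direct usage questions to the Discord server. A companion `.github/ISSUE_TEMPLATE/config.yml` (not included in this diff; sketched here as a suggestion) could disable blank issues and surface that link directly in the issue chooser:

```yaml
blank_issues_enabled: false
contact_links:
  - name: TabbyAPI Discord
    url: https://discord.gg/sYQxnuD7Fj
    about: Questions about setup and usage belong here, not in GitHub Issues.
```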
diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml
index dd59121..a7b3327 100644
--- a/.github/workflows/pages.yml
+++ b/.github/workflows/pages.yml
@@ -47,12 +47,17 @@ jobs:
         run: |
           npm install @redocly/cli -g
       - name: Export OpenAPI docs
-        run: EXPORT_OPENAPI=1 python main.py
+        run: |
+          EXPORT_OPENAPI=1 python main.py
+          mv openapi.json openapi-oai.json
+          EXPORT_OPENAPI=1 python main.py --api-servers kobold
+          mv openapi.json openapi-kobold.json
       - name: Build and store Redocly site
         run: |
-          redocly build-docs openapi.json
           mkdir static
-          mv redoc-static.html static/index.html
+          mkdir static/kobold
+          redocly build-docs openapi-oai.json -o static/index.html
+          redocly build-docs openapi-kobold.json -o static/kobold/index.html
       - name: Setup Pages
         uses: actions/configure-pages@v5
       - name: Upload artifact
diff --git a/.gitignore b/.gitignore
index ebc96c7..2761a6b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -200,5 +200,11 @@ sampler_overrides/*
 # Gpu lib preferences file
 gpu_lib.txt
 
+# Start options file
+start_options.json
+
 # OpenAPI JSON
 openapi.json
+
+# Infinity-emb cache
+.infinity_cache/
diff --git a/README.md b/README.md
index 3b1eea1..824826d 100644
--- a/README.md
+++ b/README.md
@@ -32,20 +32,41 @@
 A FastAPI based application that allows for generating text using an LLM (large language model) using the [Exllamav2 backend](https://github.com/turboderp/exllamav2)
 
+TabbyAPI is also the official API backend server for ExllamaV2.
+
 ## Disclaimer
 
-This project is marked rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies if needed.
+This project is marked as a rolling release. There may be bugs and changes down the line. Please be aware that you might need to reinstall dependencies from time to time.
 
-TabbyAPI is a hobby project solely for a small amount of users. It is not meant to run on production servers. For that, please look at other backends that support those workloads.
+TabbyAPI is a hobby project made for a small number of users. It is not meant to run on production servers. For that, please look at other solutions that support those workloads.
 
 ## Getting Started
 
 > [!IMPORTANT]
 > 
-> This README is not for getting started. Please read the Wiki.
+> This README does not contain setup instructions. Please read the Wiki.
 
 Read the [Wiki](https://github.com/theroyallab/tabbyAPI/wiki/1.-Getting-Started) for more information. It contains user-facing documentation for installation, configuration, sampling, API usage, and so much more.
 
+## Features
+
+- OpenAI compatible API
+- Loading/unloading models
+- HuggingFace model downloading
+- Embedding model support
+- JSON schema + Regex + EBNF support
+- AI Horde support
+- Speculative decoding via draft models
+- Multi-lora with independent scaling (ex. a weight of 0.9)
+- Inbuilt proxy to override client request parameters/samplers
+- Flexible Jinja2 template engine for chat completions that conforms to HuggingFace
+- Concurrent inference with asyncio
+- Utilizes modern Python paradigms
+- Continuous batching engine using paged attention
+- Fast classifier-free guidance
+
+And much more. If something is missing here, PR it in!
+
 ## Supported Model Types
 
 TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, loading, etc. Therefore, the following types of models are supported:
@@ -58,18 +79,6 @@ TabbyAPI uses Exllamav2 as a powerful and fast backend for model inference, load
 
 In addition, TabbyAPI supports parallel batching using paged attention for Nvidia Ampere GPUs and higher.
-#### Alternative Loaders/Backends
-
-If you want to use a different model type or quantization method than the ones listed above, here are some alternative backends with their own APIs:
-
-- GGUF + GGML - [KoboldCPP](https://github.com/lostruins/KoboldCPP)
-
-- Production ready + Many other quants + batching - [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine)
-
-- Production ready + batching - [VLLM](https://github.com/vllm-project/vllm)
-
-- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui)
-
 ## Contributing
 
 Use the template when creating issues or pull requests, otherwise the developers may not look at your post.
@@ -84,6 +93,17 @@ If you have a Pull Request
 
 - Describe the pull request in detail, what, and why you are changing something
 
+## Acknowledgements
+
+TabbyAPI would not exist without the work of other contributors and FOSS projects:
+
+- [ExllamaV2](https://github.com/turboderp/exllamav2)
+- [Aphrodite Engine](https://github.com/PygmalionAI/Aphrodite-engine)
+- [infinity-emb](https://github.com/michaelfeil/infinity)
+- [FastAPI](https://github.com/fastapi/fastapi)
+- [Text Generation WebUI](https://github.com/oobabooga/text-generation-webui)
+- [SillyTavern](https://github.com/SillyTavern/SillyTavern)
+
 ## Developers and Permissions
 
 Creators/Developers:
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 5f4e86b..98b5636 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -47,7 +47,7 @@ from common.templating import (
     TemplateLoadError,
     find_template_from_model,
 )
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import GenerationConfig, HuggingFaceConfig
 from common.utils import coalesce, unwrap
 
@@ -70,8 +70,9 @@ class ExllamaV2Container:
     cache_size: int = None
     cache_mode: str = "FP16"
     draft_cache_mode: str = "FP16"
-    max_batch_size: int = 20
+    max_batch_size: Optional[int] = None
     generation_config: 
Optional[GenerationConfig] = None + hf_config: Optional[HuggingFaceConfig] = None # GPU split vars gpu_split: Optional[list] = None @@ -186,6 +187,9 @@ class ExllamaV2Container: except AttributeError: pass + # Create the hf_config + self.hf_config = HuggingFaceConfig.from_file(model_directory) + # Then override the base_seq_len if present override_base_seq_len = kwargs.get("override_base_seq_len") if override_base_seq_len: @@ -213,6 +217,9 @@ class ExllamaV2Container: # Enable fasttensors loading if present self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False) + # Set max batch size to the config override + self.max_batch_size = unwrap(kwargs.get("max_batch_size")) + # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention if ( @@ -268,15 +275,8 @@ class ExllamaV2Container: else: self.cache_size = self.config.max_seq_len - # Try to set prompt template - self.prompt_template = self.find_prompt_template( - kwargs.get("prompt_template"), model_directory - ) - # Load generation config overrides - generation_config_path = ( - pathlib.Path(self.config.model_dir) / "generation_config.json" - ) + generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: self.generation_config = GenerationConfig.from_file( @@ -288,6 +288,11 @@ class ExllamaV2Container: "Skipping generation config load because of an unexpected error." 
) + # Try to set prompt template + self.prompt_template = self.find_prompt_template( + kwargs.get("prompt_template"), model_directory + ) + # Catch all for template lookup errors if self.prompt_template: logger.info( @@ -405,6 +410,9 @@ class ExllamaV2Container: def get_model_path(self, is_draft: bool = False): """Get the path for this model.""" + if is_draft and not self.draft_config: + return None + model_path = pathlib.Path( self.draft_config.model_dir if is_draft else self.config.model_dir ) @@ -446,6 +454,11 @@ class ExllamaV2Container: # Immediately abort all jobs if asked if skip_wait: + logger.warning( + "Immediately terminating all jobs. " + "Clients will have their requests cancelled.\n" + ) + # Requires a copy to avoid errors during iteration jobs_copy = self.generator.jobs.copy() for job in jobs_copy.values(): @@ -485,15 +498,7 @@ class ExllamaV2Container: yield value # Create async generator - self.generator = ExLlamaV2DynamicGeneratorAsync( - model=self.model, - cache=self.cache, - draft_model=self.draft_model, - draft_cache=self.draft_cache, - tokenizer=self.tokenizer, - max_batch_size=self.max_batch_size, - paged=self.paged, - ) + await self.create_generator() # Clean up any extra vram usage from torch and cuda # (Helps reduce VRAM bottlenecking on Windows) @@ -642,6 +647,34 @@ class ExllamaV2Container: input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.model.forward(input_ids, cache=self.cache, preprocess_only=True) + async def create_generator(self): + try: + # Don't acquire locks unless a model is loaded + if self.model_loaded: + await self.load_lock.acquire() + + # Immediately cancel all jobs + await self.wait_for_jobs(skip_wait=True) + + # Create new generator + self.generator = ExLlamaV2DynamicGeneratorAsync( + model=self.model, + cache=self.cache, + draft_model=self.draft_model, + draft_cache=self.draft_cache, + tokenizer=self.tokenizer, + max_batch_size=self.max_batch_size, + paged=self.paged, + ) + finally: + # 
This means the generator is being recreated + # The load lock is already released in the load function + if self.model_loaded: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() + def get_loras(self): """Convenience function to get all loras.""" @@ -701,11 +734,15 @@ class ExllamaV2Container: Free all VRAM resources used by this model """ - try: - await self.load_lock.acquire() + # Shutdown immediately unloads and bypasses all locks + do_shutdown = kwargs.get("shutdown") - # Wait for other jobs to finish - await self.wait_for_jobs(kwargs.get("skip_wait")) + try: + if not do_shutdown: + await self.load_lock.acquire() + + # Wait for other jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) # Delete references held in the grammar module clear_grammar_func_cache() @@ -745,10 +782,11 @@ class ExllamaV2Container: logger.info("Loras unloaded." if loras_only else "Model unloaded.") finally: - self.load_lock.release() + if not do_shutdown: + self.load_lock.release() - async with self.load_condition: - self.load_condition.notify_all() + async with self.load_condition: + self.load_condition.notify_all() def encode_tokens(self, text: str, **kwargs): """Wrapper to encode tokens from a text string""" @@ -800,10 +838,14 @@ class ExllamaV2Container: return dict(zip_longest(top_tokens, cleaned_values)) - async def generate(self, prompt: str, **kwargs): + async def generate( + self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs + ): """Generate a response to a prompt""" generations = [] - async for generation in self.generate_gen(prompt, **kwargs): + async for generation in self.generate_gen( + prompt, request_id, abort_event, **kwargs + ): generations.append(generation) joined_generation = { @@ -853,7 +895,11 @@ class ExllamaV2Container: return kwargs async def generate_gen( - self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs + self, + prompt: str, + request_id: str, + 
abort_event: Optional[asyncio.Event] = None, + **kwargs, ): """ Create generator function for prompt completion. @@ -972,8 +1018,37 @@ class ExllamaV2Container: kwargs.get("repetition_decay"), fallback_decay, 0 ) - stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) + # Initialize grammar handler + grammar_handler = ExLlamaV2Grammar() + + # Add JSON schema filter if it exists + json_schema = unwrap(kwargs.get("json_schema")) + if json_schema: + grammar_handler.add_json_schema_filter( + json_schema, self.model, self.tokenizer + ) + + # Add regex filter if it exists + regex_pattern = unwrap(kwargs.get("regex_pattern")) + if regex_pattern: + grammar_handler.add_regex_filter(regex_pattern, self.tokenizer) + + # Add EBNF filter if it exists + grammar_string = unwrap(kwargs.get("grammar_string")) + if grammar_string: + grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) + + # Set banned strings banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), []) + if banned_strings and len(grammar_handler.filters) > 0: + logger.warning( + "Disabling banned_strings because " + "they cannot be used with grammar filters." + ) + + banned_strings = [] + + stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) add_bos_token = unwrap(kwargs.get("add_bos_token"), True) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) logit_bias = kwargs.get("logit_bias") @@ -1021,26 +1096,6 @@ class ExllamaV2Container: "in the model's vocab. Skipping." 
) - # Initialize grammar handler - grammar_handler = ExLlamaV2Grammar() - - # Add JSON schema filter if it exists - json_schema = unwrap(kwargs.get("json_schema")) - if json_schema: - grammar_handler.add_json_schema_filter( - json_schema, self.model, self.tokenizer - ) - - # Add regex filter if it exists - regex_pattern = unwrap(kwargs.get("regex_pattern")) - if regex_pattern: - grammar_handler.add_regex_filter(regex_pattern, self.tokenizer) - - # Add EBNF filter if it exists - grammar_string = unwrap(kwargs.get("grammar_string")) - if grammar_string: - grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) - # Fetch EOS tokens from generation_config if they exist eos_tokens = ( self.generation_config.eos_tokens() @@ -1085,34 +1140,11 @@ class ExllamaV2Container: # This is an inverse of skip_special_tokens decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False) - # Log generation options to console - # Some options are too large, so log the args instead - log_generation_params( - max_tokens=max_tokens, - min_tokens=min_tokens, - stream=kwargs.get("stream"), - **gen_settings_log_dict, - token_healing=token_healing, - auto_scale_penalty_range=auto_scale_penalty_range, - generate_window=generate_window, - bos_token_id=self.tokenizer.bos_token_id, - eos_token_id=eos_tokens, - add_bos_token=add_bos_token, - ban_eos_token=ban_eos_token, - skip_special_tokens=not decode_special_tokens, - speculative_ngram=self.generator.speculative_ngram, - logprobs=request_logprobs, - stop_conditions=stop_conditions, - banned_tokens=banned_tokens, - banned_strings=banned_strings, - logit_bias=logit_bias, - filters=grammar_handler.filters, - ) - # Log prompt to console - log_prompt(prompt, negative_prompt) + log_prompt(prompt, request_id, negative_prompt) # Create and add a new job + # Don't use the request ID here as there can be multiple jobs per request job_id = uuid.uuid4().hex job = ExLlamaV2DynamicJobAsync( self.generator, @@ -1138,6 +1170,7 
@@ class ExllamaV2Container: max_seq_len = self.config.max_seq_len generated_tokens = 0 full_response = "" + metrics_result = {} # Get the generation status once it's ready try: @@ -1191,23 +1224,15 @@ class ExllamaV2Container: # Second yield if eos is true if result.get("eos"): - log_response(full_response) + log_response(request_id, full_response) eos_reason = result.get("eos_reason") finish_reason = ( "length" if eos_reason == "max_new_tokens" else "stop" ) - log_metrics( - result.get("time_enqueued"), - result.get("prompt_tokens"), - result.get("cached_tokens"), - result.get("time_prefill"), - result.get("new_tokens"), - result.get("time_generate"), - context_len, - max_seq_len, - ) + # Save the final result for metrics logging + metrics_result = result # Remove the token text generation = { @@ -1220,3 +1245,53 @@ class ExllamaV2Container: break except asyncio.CancelledError: await job.cancel() + except Exception as ex: + # Create a new generator since the current state is broken + # No need to wait for this to finish + logger.error( + "FATAL ERROR with generation. " + "Attempting to recreate the generator. 
" + "If this fails, please restart the server.\n" + ) + asyncio.ensure_future(self.create_generator()) + + raise ex + finally: + # Log generation options to console + # Some options are too large, so log the args instead + log_generation_params( + request_id=request_id, + max_tokens=max_tokens, + min_tokens=min_tokens, + stream=kwargs.get("stream"), + **gen_settings_log_dict, + token_healing=token_healing, + auto_scale_penalty_range=auto_scale_penalty_range, + generate_window=generate_window, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=eos_tokens, + add_bos_token=add_bos_token, + ban_eos_token=ban_eos_token, + skip_special_tokens=not decode_special_tokens, + speculative_ngram=self.generator.speculative_ngram, + logprobs=request_logprobs, + stop_conditions=stop_conditions, + banned_tokens=banned_tokens, + banned_strings=banned_strings, + logit_bias=logit_bias, + filters=grammar_handler.filters, + ) + + # Log the metrics if present + if metrics_result: + log_metrics( + request_id, + metrics_result.get("time_enqueued"), + metrics_result.get("prompt_tokens"), + metrics_result.get("cached_tokens"), + metrics_result.get("time_prefill"), + metrics_result.get("new_tokens"), + metrics_result.get("time_generate"), + context_len, + max_seq_len, + ) diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 5b1d567..cf73be4 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -1,20 +1,22 @@ +import platform +import torch from packaging import version from importlib.metadata import PackageNotFoundError, version as package_version from loguru import logger -import torch def check_exllama_version(): """Verifies the exllama version""" - required_version = version.parse("0.1.6") + required_version = version.parse("0.1.7") current_version = version.parse(package_version("exllamav2").split("+")[0]) unsupported_message = ( f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " f"or greater. 
Your current version is {current_version}.\n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" + "Please update your environment by running an update script " + "(update_scripts/" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" "Or you can manually run a requirements update " "using the following command:\n\n" "For CUDA 12.1:\n" @@ -71,8 +73,9 @@ def supports_paged_attn(): "Switching to compatibility mode. \n" "This disables parallel batching " "and features that rely on it (ex. CFG). \n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" + "Please upgrade your environment by running an update script " + "(update_scripts/" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" "Or you can manually run a requirements update " "using the following command:\n\n" "For CUDA 12.1:\n" diff --git a/backends/infinity/model.py b/backends/infinity/model.py new file mode 100644 index 0000000..35a4df4 --- /dev/null +++ b/backends/infinity/model.py @@ -0,0 +1,66 @@ +import gc +import pathlib +import torch +from loguru import logger +from typing import List, Optional + +from common.utils import unwrap + +# Conditionally import infinity to sidestep its logger +# TODO: Make this prettier +try: + from infinity_emb import EngineArgs, AsyncEmbeddingEngine + + has_infinity_emb = True +except ImportError: + has_infinity_emb = False + + +class InfinityContainer: + model_dir: pathlib.Path + model_is_loading: bool = False + model_loaded: bool = False + + # Conditionally set the type hint based on importablity + # TODO: Clean this up + if has_infinity_emb: + engine: Optional[AsyncEmbeddingEngine] = None + else: + engine = None + + def __init__(self, model_directory: pathlib.Path): + self.model_dir = model_directory + + async def load(self, **kwargs): + self.model_is_loading = True + + # Use cpu by default + device = unwrap(kwargs.get("embeddings_device"), 
"cpu") + + engine_args = EngineArgs( + model_name_or_path=str(self.model_dir), + engine="torch", + device=device, + bettertransformer=False, + model_warmup=False, + ) + + self.engine = AsyncEmbeddingEngine.from_args(engine_args) + await self.engine.astart() + + self.model_loaded = True + logger.info("Embedding model successfully loaded.") + + async def unload(self): + await self.engine.astop() + self.engine = None + + gc.collect() + torch.cuda.empty_cache() + + logger.info("Embedding model unloaded.") + + async def generate(self, sentence_input: List[str]): + result_embeddings, usage = await self.engine.embed(sentence_input) + + return {"embeddings": result_embeddings, "usage": usage} diff --git a/common/args.py b/common/args.py index 14508a7..a0f19c2 100644 --- a/common/args.py +++ b/common/args.py @@ -17,13 +17,16 @@ def init_argparser(): """Creates an argument parser that any function can use""" parser = argparse.ArgumentParser( - epilog="These args are only for a subset of the config. " - + "Please edit config.yml for all options!" + epilog="NOTE: These args serve to override parts of the config. " + + "It's highly recommended to edit config.yml for all options and " + + "better descriptions!" ) add_network_args(parser) add_model_args(parser) + add_embeddings_args(parser) add_logging_args(parser) add_developer_args(parser) + add_sampling_args(parser) add_config_args(parser) return parser @@ -64,6 +67,17 @@ def add_network_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Disable HTTP token authenticaion with requests", ) + network_group.add_argument( + "--send-tracebacks", + type=str_to_bool, + help="Decide whether to send error tracebacks over the API", + ) + network_group.add_argument( + "--api-servers", + type=str, + nargs="+", + help="API servers to enable. 
Options: (OAI, Kobold)", + ) def add_model_args(parser: argparse.ArgumentParser): @@ -74,6 +88,17 @@ def add_model_args(parser: argparse.ArgumentParser): "--model-dir", type=str, help="Overrides the directory to look for models" ) model_group.add_argument("--model-name", type=str, help="An initial model to load") + model_group.add_argument( + "--use-dummy-models", + type=str_to_bool, + help="Add dummy OAI model names for API queries", + ) + model_group.add_argument( + "--use-as-default", + type=str, + nargs="+", + help="Names of args to use as a default fallback for API load requests ", + ) model_group.add_argument( "--max-seq-len", type=int, help="Override the maximum model sequence length" ) @@ -82,25 +107,17 @@ def add_model_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Overrides base model context length", ) - model_group.add_argument( - "--cache-size", - type=int, - help="The size of the prompt cache (in number of tokens) to allocate", - ) - model_group.add_argument( - "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb" - ) - model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK") - model_group.add_argument( - "--prompt-template", - type=str, - help="Set the prompt template for chat completions", - ) model_group.add_argument( "--gpu-split-auto", type=str_to_bool, help="Automatically allocate resources to GPUs", ) + model_group.add_argument( + "--autosplit-reserve", + type=int, + nargs="+", + help="Reserve VRAM used for autosplit loading (in MBs) ", + ) model_group.add_argument( "--gpu-split", type=float, @@ -108,15 +125,44 @@ def add_model_args(parser: argparse.ArgumentParser): help="An integer array of GBs of vram to split between GPUs. 
" + "Ignored if gpu_split_auto is true", ) + model_group.add_argument( + "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb" + ) + model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK") + model_group.add_argument( + "--cache-mode", + type=str, + help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)", + ) + model_group.add_argument( + "--cache-size", + type=int, + help="The size of the prompt cache (in number of tokens) to allocate", + ) + model_group.add_argument( + "--chunk-size", + type=int, + help="Chunk size for prompt ingestion", + ) + model_group.add_argument( + "--max-batch-size", + type=int, + help="Maximum amount of prompts to process at one time", + ) + model_group.add_argument( + "--prompt-template", + type=str, + help="Set the jinja2 prompt template for chat completions", + ) model_group.add_argument( "--num-experts-per-token", type=int, help="Number of experts to use per token in MoE models", ) model_group.add_argument( - "--use-cfg", + "--fasttensors", type=str_to_bool, - help="Enables CFG support", + help="Possibly increases model loading speeds", ) @@ -132,6 +178,11 @@ def add_logging_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Enable generation parameter logging", ) + logging_group.add_argument( + "--log-requests", + type=str_to_bool, + help="Enable request logging", + ) def add_developer_args(parser: argparse.ArgumentParser): @@ -149,5 +200,38 @@ def add_developer_args(parser: argparse.ArgumentParser): developer_group.add_argument( "--cuda-malloc-backend", type=str_to_bool, - help="Disables API request streaming", + help="Runs with the pytorch CUDA malloc backend", + ) + developer_group.add_argument( + "--uvloop", + type=str_to_bool, + help="Run asyncio using Uvloop or Winloop", + ) + + +def add_sampling_args(parser: argparse.ArgumentParser): + """Adds sampling-specific arguments""" + + sampling_group = parser.add_argument_group("sampling") + 
sampling_group.add_argument( + "--override-preset", type=str, help="Select a sampler override preset" + ) + + +def add_embeddings_args(parser: argparse.ArgumentParser): + """Adds arguments specific to embeddings""" + + embeddings_group = parser.add_argument_group("embeddings") + embeddings_group.add_argument( + "--embedding-model-dir", + type=str, + help="Overrides the directory to look for models", + ) + embeddings_group.add_argument( + "--embedding-model-name", type=str, help="An initial model to load" + ) + embeddings_group.add_argument( + "--embeddings-device", + type=str, + help="Device to use for embeddings. Options: (cpu, auto, cuda)", ) diff --git a/common/auth.py b/common/auth.py index fa53262..174208d 100644 --- a/common/auth.py +++ b/common/auth.py @@ -5,11 +5,13 @@ application, it should be fine. import secrets import yaml -from fastapi import Header, HTTPException +from fastapi import Header, HTTPException, Request from pydantic import BaseModel from loguru import logger from typing import Optional +from common.utils import coalesce + class AuthKeys(BaseModel): """ @@ -75,7 +77,27 @@ def load_auth_keys(disable_from_config: bool): ) -async def validate_key_permission(test_key: str): +def get_key_permission(request: Request): + """ + Gets the key permission from a request. + + Internal only! Use the depends functions for incoming requests. 
+ """ + + # Give full admin permissions if auth is disabled + if DISABLE_AUTH: + return "admin" + + # Hyphens are okay here + test_key = coalesce( + request.headers.get("x-admin-key"), + request.headers.get("x-api-key"), + request.headers.get("authorization"), + ) + + if test_key is None: + raise ValueError("The provided authentication key is missing.") + if test_key.lower().startswith("bearer"): test_key = test_key.split(" ")[1] diff --git a/common/config.py b/common/config.py index 86aedac..9b2f654 100644 --- a/common/config.py +++ b/common/config.py @@ -46,15 +46,12 @@ def from_args(args: dict): GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override} # Generation Logging config - gen_logging_override = args.get("logging") - if gen_logging_override: - cur_gen_logging_config = gen_logging_config() + logging_override = args.get("logging") + if logging_override: + cur_logging_config = logging_config() GLOBAL_CONFIG["logging"] = { - **cur_gen_logging_config, - **{ - k.replace("log_", ""): gen_logging_override[k] - for k in gen_logging_override - }, + **cur_logging_config, + **{k.replace("log_", ""): logging_override[k] for k in logging_override}, } developer_override = args.get("developer") @@ -62,6 +59,11 @@ def from_args(args: dict): cur_developer_config = developer_config() GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override} + embeddings_override = args.get("embeddings") + if embeddings_override: + cur_embeddings_config = embeddings_config() + GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override} + def sampling_config(): """Returns the sampling parameter config from the global config""" @@ -90,11 +92,16 @@ def network_config(): return unwrap(GLOBAL_CONFIG.get("network"), {}) -def gen_logging_config(): - """Returns the generation logging config from the global config""" +def logging_config(): + """Returns the logging config from the global config""" return unwrap(GLOBAL_CONFIG.get("logging"), {}) def 
developer_config(): """Returns the developer specific config from the global config""" return unwrap(GLOBAL_CONFIG.get("developer"), {}) + + +def embeddings_config(): + """Returns the embeddings config from the global config""" + return unwrap(GLOBAL_CONFIG.get("embeddings"), {}) diff --git a/common/downloader.py b/common/downloader.py index a0b16cb..b9e1b72 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -101,6 +101,7 @@ async def hf_repo_download( chunk_limit: Optional[float], include: Optional[List[str]], exclude: Optional[List[str]], + timeout: Optional[int], repo_type: Optional[str] = "model", ): """Gets a repo's information from HuggingFace and downloads it locally.""" @@ -145,7 +146,8 @@ async def hf_repo_download( logger.info(f"Saving {repo_id} to {str(download_path)}") try: - async with aiohttp.ClientSession() as session: + client_timeout = aiohttp.ClientTimeout(total=timeout) # Turn off timeout + async with aiohttp.ClientSession(timeout=client_timeout) as session: tasks = [] logger.info(f"Starting download for {repo_id}") diff --git a/common/gen_logging.py b/common/gen_logging.py index fbf10f6..9995818 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -51,25 +51,31 @@ def log_generation_params(**kwargs): logger.info(f"Generation options: {kwargs}\n") -def log_prompt(prompt: str, negative_prompt: Optional[str]): +def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]): """Logs the prompt to console.""" if PREFERENCES.prompt: formatted_prompt = "\n" + prompt - logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n") + logger.info( + f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n" + ) if negative_prompt: formatted_negative_prompt = "\n" + negative_prompt logger.info(f"Negative Prompt: {formatted_negative_prompt}\n") -def log_response(response: str): +def log_response(request_id: str, response: str): """Logs the response to console.""" if PREFERENCES.prompt: 
formatted_response = "\n" + response - logger.info(f"Response: {formatted_response if response else 'Empty'}\n") + logger.info( + f"Response (ID: {request_id}): " + f"{formatted_response if response else 'Empty'}\n" + ) def log_metrics( + request_id: str, queue_time: float, prompt_tokens: int, cached_tokens: int, @@ -80,7 +86,7 @@ def log_metrics( max_seq_len: int, ): initial_response = ( - f"Metrics: {generated_tokens} tokens generated in " + f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in " f"{round(queue_time + prompt_time + generate_time, 2)} seconds" ) itemization = [] diff --git a/common/model.py b/common/model.py index b925f15..0bfbab2 100644 --- a/common/model.py +++ b/common/model.py @@ -5,11 +5,14 @@ Containers exist as a common interface for backends. """ import pathlib +from enum import Enum +from fastapi import HTTPException from loguru import logger from typing import Optional from common import config from common.logger import get_loading_progress_bar +from common.networking import handle_request_error from common.utils import unwrap from endpoints.utils import do_export_openapi @@ -18,6 +21,21 @@ if not do_export_openapi: # Global model container container: Optional[ExllamaV2Container] = None + embeddings_container = None + + # Type hint the infinity emb container if it exists + from backends.infinity.model import has_infinity_emb + + if has_infinity_emb: + from backends.infinity.model import InfinityContainer + + embeddings_container: Optional[InfinityContainer] = None + + +class ModelType(Enum): + MODEL = "model" + DRAFT = "draft" + EMBEDDING = "embedding" def load_progress(module, modules): @@ -25,11 +43,11 @@ def load_progress(module, modules): yield module, modules -async def unload_model(skip_wait: bool = False): +async def unload_model(skip_wait: bool = False, shutdown: bool = False): """Unloads a model""" global container - await container.unload(skip_wait=skip_wait) + await container.unload(skip_wait=skip_wait, 
shutdown=shutdown) container = None @@ -46,8 +64,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): f'Model "{loaded_model_name}" is already loaded! Aborting.' ) - # Unload the existing model - if container and container.model: logger.info("Unloading existing model.") await unload_model() @@ -98,17 +114,89 @@ async def unload_loras(): await container.unload(loras_only=True) -def get_config_default(key, fallback=None, is_draft=False): +async def load_embedding_model(model_path: pathlib.Path, **kwargs): + global embeddings_container + + # Break out if infinity isn't installed + if not has_infinity_emb: + raise ImportError( + "Skipping embeddings because infinity-emb is not installed.\n" + "Please run the following command in your environment " + "to install extra packages:\n" + "pip install -U .[extras]" + ) + + # Check if the model is already loaded + if embeddings_container and embeddings_container.engine: + loaded_model_name = embeddings_container.model_dir.name + + if loaded_model_name == model_path.name and embeddings_container.model_loaded: + raise ValueError( + f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.' + ) + + logger.info("Unloading existing embeddings model.") + await unload_embedding_model() + + embeddings_container = InfinityContainer(model_path) + await embeddings_container.load(**kwargs) + + +async def unload_embedding_model(): + global embeddings_container + + await embeddings_container.unload() + embeddings_container = None + + +def get_config_default(key: str, fallback=None, model_type: str = "model"): """Fetches a default value from model config if allowed by the user.""" model_config = config.model_config() default_keys = unwrap(model_config.get("use_as_default"), []) + + # Add extra keys to defaults + default_keys.append("embeddings_device") + if key in default_keys: # Is this a draft model load parameter? 
- if is_draft: + if model_type == "draft": draft_config = config.draft_model_config() return unwrap(draft_config.get(key), fallback) + elif model_type == "embedding": + embeddings_config = config.embeddings_config() + return unwrap(embeddings_config.get(key), fallback) else: return unwrap(model_config.get(key), fallback) else: return fallback + + +async def check_model_container(): + """FastAPI depends that checks if a model isn't loaded or currently loading.""" + + if container is None or not (container.model_is_loading or container.model_loaded): + error_message = handle_request_error( + "No models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + +async def check_embeddings_container(): + """ + FastAPI depends that checks if an embeddings model is loaded. + + This is the same as the model container check, but with embeddings instead. + """ + + if embeddings_container is None or not ( + embeddings_container.model_is_loading or embeddings_container.model_loaded + ): + error_message = handle_request_error( + "No embedding models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) diff --git a/common/networking.py b/common/networking.py index 47ebe06..7c088a9 100644 --- a/common/networking.py +++ b/common/networking.py @@ -1,12 +1,17 @@ """Common utility functions""" import asyncio +import json import socket import traceback -from fastapi import Request +from fastapi import Depends, HTTPException, Request from loguru import logger from pydantic import BaseModel from typing import Optional +from uuid import uuid4 + +from common import config +from common.utils import unwrap class TabbyRequestErrorMessage(BaseModel): @@ -33,15 +38,18 @@ def get_generator_error(message: str, exc_info: bool = True): def handle_request_error(message: str, exc_info: bool = True): """Log a request error to the console.""" + trace = traceback.format_exc() + send_trace = 
unwrap(config.network_config().get("send_tracebacks"), False) + error_message = TabbyRequestErrorMessage( - message=message, trace=traceback.format_exc() + message=message, trace=trace if send_trace else None ) request_error = TabbyRequestError(error=error_message) # Log the error and provided message to the console - if error_message.trace and exc_info: - logger.error(error_message.trace) + if trace and exc_info: + logger.error(trace) logger.error(f"Sent to request: {message}") @@ -78,8 +86,9 @@ async def run_with_request_disconnect( try: return call_task.result() - except (asyncio.CancelledError, asyncio.InvalidStateError): + except (asyncio.CancelledError, asyncio.InvalidStateError) as ex: handle_request_disconnect(disconnect_message) + raise HTTPException(422, disconnect_message) from ex def is_port_in_use(port: int) -> bool: @@ -93,3 +102,39 @@ def is_port_in_use(port: int) -> bool: test_socket.settimeout(1) with test_socket: return test_socket.connect_ex(("localhost", port)) == 0 + + +async def add_request_id(request: Request): + """FastAPI depends to add a UUID to a request's state.""" + + request.state.id = uuid4().hex + return request + + +async def log_request(request: Request): + """FastAPI depends to log a request to the user.""" + + log_message = [f"Information for {request.method} request {request.state.id}:"] + + log_message.append(f"URL: {request.url}") + log_message.append(f"Headers: {dict(request.headers)}") + + if request.method != "GET": + body_bytes = await request.body() + if body_bytes: + body = json.loads(body_bytes.decode("utf-8")) + + log_message.append(f"Body: {dict(body)}") + + logger.info("\n".join(log_message)) + + +def get_global_depends(): + """Returns global dependencies for a FastAPI app.""" + + depends = [Depends(add_request_id)] + + if config.logging_config().get("requests"): + depends.append(Depends(log_request)) + + return depends diff --git a/common/sampling.py b/common/sampling.py index bbeddb8..72552ce 100644 --- 
a/common/sampling.py +++ b/common/sampling.py @@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel): max_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("max_tokens"), + validation_alias=AliasChoices("max_tokens", "max_length"), + description="Aliases: max_length", examples=[150], ) min_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("min_tokens", 0), + validation_alias=AliasChoices("min_tokens", "min_length"), + description="Aliases: min_length", examples=[0], ) @@ -30,13 +34,22 @@ class BaseSamplerRequest(BaseModel): ) stop: Optional[Union[str, List[str]]] = Field( - default_factory=lambda: get_default_sampler_value("stop", []) + default_factory=lambda: get_default_sampler_value("stop", []), + validation_alias=AliasChoices("stop", "stop_sequence"), + description="Aliases: stop_sequence", ) banned_strings: Optional[Union[str, List[str]]] = Field( default_factory=lambda: get_default_sampler_value("banned_strings", []) ) + banned_tokens: Optional[Union[List[int], str]] = Field( + default_factory=lambda: get_default_sampler_value("banned_tokens", []), + validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), + description="Aliases: custom_token_bans", + examples=[[128, 330]], + ) + token_healing: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("token_healing", False) ) @@ -76,6 +89,13 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) + typical: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("typical", 1.0), + validation_alias=AliasChoices("typical", "typical_p"), + description="Aliases: typical_p", + examples=[1.0], + ) + skew: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("skew", 0.0), examples=[0.0], @@ -91,9 +111,24 @@ class BaseSamplerRequest(BaseModel): repetition_penalty: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + 
validation_alias=AliasChoices("repetition_penalty", "rep_pen"), + description="Aliases: rep_pen", examples=[1.0], ) + penalty_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("penalty_range", -1), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + "rep_pen_range", + ), + description=( + "Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range" + ), + ) + repetition_decay: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) @@ -118,6 +153,8 @@ class BaseSamplerRequest(BaseModel): ban_eos_token: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("ban_eos_token", False), + validation_alias=AliasChoices("ban_eos_token", "ignore_eos"), + description="Aliases: ignore_eos", examples=[False], ) @@ -151,24 +188,6 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("speculative_ngram"), ) - # Aliased variables - typical: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("typical", 1.0), - validation_alias=AliasChoices("typical", "typical_p"), - description="Aliases: typical_p", - examples=[1.0], - ) - - penalty_range: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("penalty_range", -1), - validation_alias=AliasChoices( - "penalty_range", - "repetition_range", - "repetition_penalty_range", - ), - description="Aliases: repetition_range, repetition_penalty_range", - ) - cfg_scale: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), @@ -196,13 +215,6 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) - banned_tokens: Optional[Union[List[int], str]] = Field( - default_factory=lambda: get_default_sampler_value("banned_tokens", []), - validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), - 
description="Aliases: custom_token_bans", - examples=[[128, 330]], - ) - # TODO: Return back to adaptable class-based validation But that's just too much # abstraction compared to simple if statements at the moment def validate_params(self): diff --git a/common/signals.py b/common/signals.py index 07d7564..97f595b 100644 --- a/common/signals.py +++ b/common/signals.py @@ -1,13 +1,40 @@ +import asyncio import signal import sys from loguru import logger from types import FrameType +from common import model + + +SHUTTING_DOWN: bool = False + def signal_handler(*_): """Signal handler for main function. Run before uvicorn starts.""" + global SHUTTING_DOWN + + if SHUTTING_DOWN: + return + logger.warning("Shutdown signal called. Exiting gracefully.") + SHUTTING_DOWN = True + + # Run async unloads for model + asyncio.ensure_future(signal_handler_async()) + + +async def signal_handler_async(*_): + """Internal signal handler. Runs all async code to shut down the program.""" + + if model.container: + await model.unload_model(skip_wait=True, shutdown=True) + + if model.embeddings_container: + await model.unload_embedding_model() + + # Exit the program sys.exit(0) diff --git a/common/templating.py b/common/templating.py index f742386..7a59946 100644 --- a/common/templating.py +++ b/common/templating.py @@ -51,7 +51,7 @@ class PromptTemplate: raise ImportError( "Parsing these chat completion messages requires jinja2 3.0.0 " f"or greater. 
Current version: {package_version('jinja2')}\n" - "Please upgrade jinja by running the following command: " + "Please update jinja by running the following command: " "pip install --upgrade jinja2" ) diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 62d4622..9db8ad2 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,6 +1,7 @@ import json import pathlib from typing import List, Optional, Union +from loguru import logger from pydantic import BaseModel @@ -11,6 +12,7 @@ class GenerationConfig(BaseModel): """ eos_token_id: Optional[Union[int, List[int]]] = None + bad_words_ids: Optional[List[List[int]]] = None @classmethod def from_file(self, model_directory: pathlib.Path): @@ -30,3 +32,38 @@ class GenerationConfig(BaseModel): return [self.eos_token_id] else: return self.eos_token_id + + +class HuggingFaceConfig(BaseModel): + """ + An abridged version of HuggingFace's model config. + Will be expanded as needed. + """ + + badwordsids: Optional[str] = None + + @classmethod + def from_file(self, model_directory: pathlib.Path): + """Create an instance from a generation config file.""" + + hf_config_path = model_directory / "config.json" + with open(hf_config_path, "r", encoding="utf8") as hf_config_json: + hf_config_dict = json.load(hf_config_json) + return self.model_validate(hf_config_dict) + + def get_badwordsids(self): + """Wrapper method to fetch badwordsids.""" + + if self.badwordsids: + try: + bad_words_list = json.loads(self.badwordsids) + return bad_words_list + except json.JSONDecodeError: + logger.warning( + "Skipping badwordsids from config.json " + "since it's not a valid array." 
+ ) + + return [] + else: + return [] diff --git a/common/utils.py b/common/utils.py index 6787f39..b120022 100644 --- a/common/utils.py +++ b/common/utils.py @@ -18,3 +18,9 @@ def prune_dict(input_dict): """Trim out instances of None from a dictionary.""" return {k: v for k, v in input_dict.items() if v is not None} + + +def flat_map(input_list): + """Flattens a list of lists into a single list.""" + + return [item for sublist in input_list for item in sublist] diff --git a/config_sample.yml b/config_sample.yml index 0e4b180..018ff61 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -19,6 +19,14 @@ network: # Turn on this option if you are ONLY connecting from localhost disable_auth: False + # Send tracebacks over the API to clients (default: False) + # NOTE: Only enable this for debug purposes + send_tracebacks: False + + # Select API servers to enable (default: ["OAI"]) + # Possible values: OAI + api_servers: ["OAI"] + # Options for logging logging: # Enable prompt logging (default: False) @@ -27,6 +35,10 @@ logging: # Enable generation parameter logging (default: False) generation_params: False + # Enable request logging (default: False) + # NOTE: Only use this for debugging! + requests: False + # Options for sampling sampling: # Override preset name. Find this in the sampler-overrides folder (default: None) @@ -50,6 +62,16 @@ developer: # This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk. #cuda_malloc_backend: False + # Enable Uvloop or Winloop (default: False) + # Make the program utilize a faster async event loop which can improve performance + # NOTE: It's recommended to enable this, but if something breaks, turn this off. 
+ #uvloop: False + + # Set process to use a higher priority + # For realtime process priority, run as administrator or sudo + # Otherwise, the priority will be set to high + #realtime_process_priority: False + # Options for model overrides and loading # Please read the comments to understand how arguments are handled between initial and API loads model: @@ -124,11 +146,11 @@ model: # NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096 #chunk_size: 2048 - # Set the maximum amount of prompts to process at one time (batch) - # This will be automatically adjusted depending on the cache size. + # Set the maximum amount of prompts to process at one time (default: None/Automatic) + # This will be automatically calculated if left blank. # A max batch size of 1 processes prompts one at a time. # NOTE: Only available for Nvidia ampere (30 series) and above GPUs - #max_batch_size: 20 + #max_batch_size: # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None) # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name @@ -179,3 +201,22 @@ model: #loras: #- name: lora1 # scaling: 1.0 + +# Options for embedding models and loading. +# NOTE: Embeddings requires the "extras" feature to be installed +# Install it via "pip install .[extras]" +embeddings: + # Overrides directory to look for embedding models (default: models) + embedding_model_dir: models + + # Device to load embedding models on (default: cpu) + # Possible values: cpu, auto, cuda + # NOTE: It's recommended to load embedding models on the CPU. + # If you'd like to load on an AMD gpu, set this value to "cuda" as well. 
+ embeddings_device: cpu + + # The below parameters only apply for initial loads + # All API based loads do NOT inherit these settings unless specified in use_as_default + + # An initial embedding model to load on the infinity backend (default: None) + embedding_model_name: diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index fd6634c..c337b20 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -12,7 +12,7 @@ services: - NAME=TabbyAPI - NVIDIA_VISIBLE_DEVICES=all volumes: - - ./models:/usr/src/app/models + - ./models:/app/models deploy: resources: reservations: diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py new file mode 100644 index 0000000..334bae2 --- /dev/null +++ b/endpoints/Kobold/router.py @@ -0,0 +1,161 @@ +from sys import maxsize +from fastapi import APIRouter, Depends, Request +from sse_starlette import EventSourceResponse + +from common import model +from common.auth import check_api_key +from common.model import check_model_container +from common.utils import unwrap +from endpoints.core.utils.model import get_current_model +from endpoints.Kobold.types.generation import ( + AbortRequest, + AbortResponse, + CheckGenerateRequest, + GenerateRequest, + GenerateResponse, +) +from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse +from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse +from endpoints.Kobold.utils.generation import ( + abort_generation, + generation_status, + get_generation, + stream_generation, +) + + +api_name = "KoboldAI" +router = APIRouter(prefix="/api") +urls = { + "Generation": "http://{host}:{port}/api/v1/generate", + "Streaming": "http://{host}:{port}/api/extra/generate/stream", +} + +kai_router = APIRouter() +extra_kai_router = APIRouter() + + +def setup(): + router.include_router(kai_router, prefix="/v1") + router.include_router(kai_router, prefix="/latest", include_in_schema=False) + 
router.include_router(extra_kai_router, prefix="/extra") + + return router + + +@kai_router.post( + "/generate", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: + response = await get_generation(data, request) + + return response + + +@extra_kai_router.post( + "/generate/stream", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: + response = EventSourceResponse(stream_generation(data, request), ping=maxsize) + + return response + + +@extra_kai_router.post( + "/abort", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def abort_generate(data: AbortRequest) -> AbortResponse: + response = await abort_generation(data.genkey) + + return response + + +@extra_kai_router.get( + "/generate/check", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@extra_kai_router.post( + "/generate/check", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: + response = await generation_status(data.genkey) + + return response + + +@kai_router.get( + "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] +) +async def current_model() -> CurrentModelResponse: + """Fetches the current model and who owns it.""" + + current_model_card = get_current_model() + return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"} + + +@extra_kai_router.post( + "/tokencount", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: + raw_tokens = model.container.encode_tokens(data.prompt) + tokens = unwrap(raw_tokens, []) + return TokenCountResponse(value=len(tokens), ids=tokens) + + 
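The endpoints above lean on small helpers from `common/utils.py`: `unwrap(value, default)` (used as `unwrap(raw_tokens, [])` in the `tokencount` handler), `coalesce(...)` (used by `get_key_permission` to pick the first auth header present), and the `flat_map(...)` added in this diff. A minimal sketch of how these helpers behave — `flat_map` is copied from the diff, while the `unwrap` and `coalesce` bodies here are illustrative assumptions, not the repo's exact implementations:

```python
from typing import Any, Optional


def unwrap(value: Optional[Any], default: Any = None) -> Any:
    # Assumed behavior: return the value unless it is None, else the default
    return value if value is not None else default


def coalesce(*values: Optional[Any]) -> Optional[Any]:
    # Assumed behavior: first non-None argument, like SQL's COALESCE
    return next((v for v in values if v is not None), None)


def flat_map(input_list):
    """Flattens a list of lists into a single list."""
    return [item for sublist in input_list for item in sublist]


# Mirrors the tokencount endpoint: a None encode result becomes an empty list
tokens = unwrap(None, [])
assert tokens == []

# Mirrors get_key_permission: pick the first auth header that is present
headers = {"x-api-key": "abc123"}
key = coalesce(headers.get("x-admin-key"), headers.get("x-api-key"))
assert key == "abc123"

# Mirrors to_gen_params: bad_words_ids flattened into a banned_tokens list
assert flat_map([[128], [330]]) == [128, 330]
```

These one-liners are why the diff can treat "missing header", "missing config key", and "model returned None" uniformly without scattering `if x is None` checks through the routers.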
+@kai_router.get( + "/config/max_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@kai_router.get( + "/config/max_context_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +@extra_kai_router.get( + "/true_max_context_length", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def get_max_length() -> MaxLengthResponse: + """Fetches the max length of the model.""" + + max_length = model.container.get_model_parameters().get("max_seq_len") + return {"value": max_length} + + +@kai_router.get("/info/version") +async def get_version(): + """Impersonate KAI United.""" + + return {"result": "1.2.5"} + + +@extra_kai_router.get("/version") +async def get_extra_version(): + """Impersonate Koboldcpp.""" + + return {"result": "KoboldCpp", "version": "1.61"} + + +@kai_router.get("/config/soft_prompts_list") +async def get_available_softprompts(): + """Used for KAI compliance.""" + + return {"values": []} + + +@kai_router.get("/config/soft_prompt") +async def get_current_softprompt(): + """Used for KAI compliance.""" + + return {"value": ""} + + +@kai_router.put("/config/soft_prompt") +async def set_current_softprompt(): + """Used for KAI compliance.""" + + return {} diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py new file mode 100644 index 0000000..310484b --- /dev/null +++ b/endpoints/Kobold/types/generation.py @@ -0,0 +1,60 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field +from common import model +from common.sampling import BaseSamplerRequest, get_default_sampler_value +from common.utils import flat_map, unwrap + + +class GenerateRequest(BaseSamplerRequest): + prompt: str + genkey: Optional[str] = None + use_default_badwordsids: Optional[bool] = False + dynatemp_range: Optional[float] = Field( + default_factory=get_default_sampler_value("dynatemp_range") + ) + + def to_gen_params(self, **kwargs): + # 
Exl2 uses -1 to include all tokens in repetition penalty + if self.penalty_range == 0: + self.penalty_range = -1 + + if self.dynatemp_range: + self.min_temp = self.temperature - self.dynatemp_range + self.max_temp = self.temperature + self.dynatemp_range + + # Move badwordsids into banned tokens for generation + if self.use_default_badwordsids: + bad_words_ids = unwrap( + model.container.generation_config.bad_words_ids, + model.container.hf_config.get_badwordsids(), + ) + + if bad_words_ids: + self.banned_tokens += flat_map(bad_words_ids) + + return super().to_gen_params(**kwargs) + + +class GenerateResponseResult(BaseModel): + text: str + + +class GenerateResponse(BaseModel): + results: List[GenerateResponseResult] = Field(default_factory=list) + + +class StreamGenerateChunk(BaseModel): + token: str + + +class AbortRequest(BaseModel): + genkey: str + + +class AbortResponse(BaseModel): + success: bool + + +class CheckGenerateRequest(BaseModel): + genkey: str diff --git a/endpoints/Kobold/types/model.py b/endpoints/Kobold/types/model.py new file mode 100644 index 0000000..8f7276e --- /dev/null +++ b/endpoints/Kobold/types/model.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class CurrentModelResponse(BaseModel): + result: str + + +class MaxLengthResponse(BaseModel): + value: int diff --git a/endpoints/Kobold/types/token.py b/endpoints/Kobold/types/token.py new file mode 100644 index 0000000..e6639d9 --- /dev/null +++ b/endpoints/Kobold/types/token.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel +from typing import List + + +class TokenCountRequest(BaseModel): + """Represents a KAI tokenization request.""" + + prompt: str + + +class TokenCountResponse(BaseModel): + """Represents a KAI tokenization response.""" + + value: int + ids: List[int] diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py new file mode 100644 index 0000000..5febcff --- /dev/null +++ b/endpoints/Kobold/utils/generation.py @@ -0,0 +1,151 @@ +import 
asyncio +from asyncio import CancelledError +from fastapi import HTTPException, Request +from loguru import logger +from sse_starlette import ServerSentEvent + +from common import model +from common.networking import ( + get_generator_error, + handle_request_disconnect, + handle_request_error, + request_disconnect_loop, +) +from common.utils import unwrap +from endpoints.Kobold.types.generation import ( + AbortResponse, + GenerateRequest, + GenerateResponse, + GenerateResponseResult, + StreamGenerateChunk, +) + + +generation_cache = {} + + +async def override_request_id(request: Request, data: GenerateRequest): + """Overrides the request ID with a KAI genkey if present.""" + + if data.genkey: + request.state.id = data.genkey + + +def _create_response(text: str): + results = [GenerateResponseResult(text=text)] + return GenerateResponse(results=results) + + +def _create_stream_chunk(text: str): + return StreamGenerateChunk(token=text) + + +async def _stream_collector(data: GenerateRequest, request: Request): + """Common async generator for generation streams.""" + + abort_event = asyncio.Event() + disconnect_task = asyncio.create_task(request_disconnect_loop(request)) + + # Create a new entry in the cache + generation_cache[data.genkey] = {"abort": abort_event, "text": ""} + + try: + logger.info(f"Received Kobold generation request {data.genkey}") + + generator = model.container.generate_gen( + data.prompt, data.genkey, abort_event, **data.to_gen_params() + ) + async for generation in generator: + if disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Kobold generation {data.genkey} cancelled by user." 
+ ) + + text = generation.get("text") + + # Update the generation cache with the new chunk + if text: + generation_cache[data.genkey]["text"] += text + yield text + + if "finish_reason" in generation: + logger.info(f"Finished streaming Kobold request {data.genkey}") + break + except CancelledError: + # If the request disconnects, break out + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Kobold generation {data.genkey} cancelled by user." + ) + finally: + # Cleanup the cache + del generation_cache[data.genkey] + + +async def stream_generation(data: GenerateRequest, request: Request): + """Wrapper for stream generations.""" + + # If the genkey doesn't exist, set it to the request ID + if not data.genkey: + data.genkey = request.state.id + + try: + async for chunk in _stream_collector(data, request): + response = _create_stream_chunk(chunk) + yield ServerSentEvent( + event="message", data=response.model_dump_json(), sep="\n" + ) + except Exception: + yield get_generator_error( + f"Kobold generation {data.genkey} aborted. " + "Please check the server console." + ) + + +async def get_generation(data: GenerateRequest, request: Request): + """Wrapper to get a static generation.""" + + # If the genkey doesn't exist, set it to the request ID + if not data.genkey: + data.genkey = request.state.id + + try: + full_text = "" + async for chunk in _stream_collector(data, request): + full_text += chunk + + response = _create_response(full_text) + return response + except Exception as exc: + error_message = handle_request_error( + f"Completion {request.state.id} aborted. Maybe the model was unloaded? " + "Please check the server console." 
+ ).error.message + + # Server error if there's a generation exception + raise HTTPException(503, error_message) from exc + + +async def abort_generation(genkey: str): + """Aborts a generation from the cache.""" + + abort_event = unwrap(generation_cache.get(genkey), {}).get("abort") + if abort_event: + abort_event.set() + handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.") + + return AbortResponse(success=True) + + +async def generation_status(genkey: str): + """Fetches the status of a generation from the cache.""" + + current_text = unwrap(generation_cache.get(genkey), {}).get("text") + if current_text: + response = _create_response(current_text) + else: + response = GenerateResponse() + + return response diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index f4cc516..d1b0237 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,48 +1,21 @@ import asyncio import pathlib from loguru import logger -from fastapi import APIRouter, Depends, HTTPException, Header, Request +from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize -from typing import Optional -from common import config, model, gen_logging, sampling -from common.auth import check_admin_key, check_api_key, validate_key_permission -from common.downloader import hf_repo_download +from common import config, model +from common.auth import check_api_key +from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect -from common.templating import PromptTemplate, get_all_templates -from common.utils import coalesce, unwrap -from endpoints.OAI.types.auth import AuthPermissionResponse +from common.utils import unwrap from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse from endpoints.OAI.types.chat_completion import ( ChatCompletionRequest, ChatCompletionResponse, ) 
-from endpoints.OAI.types.download import DownloadRequest, DownloadResponse -from endpoints.OAI.types.lora import ( - LoraCard, - LoraList, - LoraLoadRequest, - LoraLoadResponse, -) -from endpoints.OAI.types.model import ( - ModelCard, - ModelList, - ModelLoadRequest, - ModelCardParameters, - ModelLoadResponse, -) -from endpoints.OAI.types.sampler_overrides import ( - SamplerOverrideListResponse, - SamplerOverrideSwitchRequest, -) -from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest -from endpoints.OAI.types.token import ( - TokenEncodeRequest, - TokenEncodeResponse, - TokenDecodeRequest, - TokenDecodeResponse, -) +from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse from endpoints.OAI.utils.chat_completion import ( format_prompt_with_template, generate_chat_completion, @@ -52,25 +25,19 @@ from endpoints.OAI.utils.completion import ( generate_completion, stream_generate_completion, ) -from endpoints.OAI.utils.model import get_model_list, stream_model_load -from endpoints.OAI.utils.lora import get_lora_list +from endpoints.OAI.utils.embeddings import get_embeddings +api_name = "OAI" router = APIRouter() +urls = { + "Completions": "http://{host}:{port}/v1/completions", + "Chat completions": "http://{host}:{port}/v1/chat/completions", +} -async def check_model_container(): - """FastAPI depends that checks if a model isn't loaded or currently loading.""" - - if model.container is None or not ( - model.container.model_is_loading or model.container.model_loaded - ): - error_message = handle_request_error( - "No models are currently loaded.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) +def setup(): + return router # Completions endpoint @@ -106,12 +73,14 @@ async def completion_request( ping=maxsize, ) else: - generate_task = asyncio.create_task(generate_completion(data, model_path)) + generate_task = asyncio.create_task( + generate_completion(data, request, model_path) + ) response = 
await run_with_request_disconnect( request, generate_task, - disconnect_message="Completion generation cancelled by user.", + disconnect_message=f"Completion {request.state.id} cancelled by user.", ) return response @@ -206,392 +175,28 @@ async def chat_completion_request( ) else: generate_task = asyncio.create_task( - generate_chat_completion(prompt, data, model_path) + generate_chat_completion(prompt, data, request, model_path) ) response = await run_with_request_disconnect( request, generate_task, - disconnect_message="Chat completion generation cancelled by user.", + disconnect_message=f"Chat completion {request.state.id} cancelled by user.", ) return response -# Model list endpoint -@router.get("/v1/models", dependencies=[Depends(check_api_key)]) -@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) -async def list_models() -> ModelList: - """Lists all models in the model directory.""" - model_config = config.model_config() - model_dir = unwrap(model_config.get("model_dir"), "models") - model_path = pathlib.Path(model_dir) - - draft_model_dir = config.draft_model_config().get("draft_model_dir") - - models = get_model_list(model_path.resolve(), draft_model_dir) - if unwrap(model_config.get("use_dummy_models"), False): - models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) - - return models - - -# Currently loaded model endpoint -@router.get( - "/v1/model", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def get_current_model() -> ModelCard: - """Returns the currently loaded model.""" - model_params = model.container.get_model_parameters() - draft_model_params = model_params.pop("draft", {}) - - if draft_model_params: - model_params["draft"] = ModelCard( - id=unwrap(draft_model_params.get("name"), "unknown"), - parameters=ModelCardParameters.model_validate(draft_model_params), - ) - else: - draft_model_params = None - - model_card = ModelCard( - id=unwrap(model_params.pop("name", None), "unknown"), - 
parameters=ModelCardParameters.model_validate(model_params), - logging=gen_logging.PREFERENCES, - ) - - if draft_model_params: - draft_card = ModelCard( - id=unwrap(draft_model_params.pop("name", None), "unknown"), - parameters=ModelCardParameters.model_validate(draft_model_params), - ) - - model_card.parameters.draft = draft_card - - return model_card - - -@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) -async def list_draft_models() -> ModelList: - """Lists all draft models in the model directory.""" - draft_model_dir = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - draft_model_path = pathlib.Path(draft_model_dir) - - models = get_model_list(draft_model_path.resolve()) - - return models - - -# Load model endpoint -@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) -async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: - """Loads a model into the model container. This returns an SSE stream.""" - - # Verify request parameters - if not data.name: - error_message = handle_request_error( - "A model name was not provided for load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) - model_path = model_path / data.name - - draft_model_path = None - if data.draft: - if not data.draft.draft_model_name: - error_message = handle_request_error( - "Could not find the draft model name for model load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - draft_model_path = unwrap( - config.draft_model_config().get("draft_model_dir"), "models" - ) - - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. 
Check model name or config.yml?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - return EventSourceResponse( - stream_model_load(data, model_path, draft_model_path), ping=maxsize - ) - - -# Unload model endpoint +# Embeddings endpoint @router.post( - "/v1/model/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], + "/v1/embeddings", + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], ) -async def unload_model(): - """Unloads the currently loaded model.""" - await model.unload_model(skip_wait=True) - - -@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) -async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: - """Downloads a model from HuggingFace.""" - - try: - download_task = asyncio.create_task(hf_repo_download(**data.model_dump())) - - # For now, the downloader and request data are 1:1 - download_path = await run_with_request_disconnect( - request, - download_task, - "Download request cancelled by user. 
Files have been cleaned up.", - ) - - return DownloadResponse(download_path=str(download_path)) - except Exception as exc: - error_message = handle_request_error(str(exc)).error.message - - raise HTTPException(400, error_message) from exc - - -# Lora list endpoint -@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) -@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) -async def get_all_loras() -> LoraList: - """Lists all LoRAs in the lora directory.""" - lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - loras = get_lora_list(lora_path.resolve()) - - return loras - - -# Currently loaded loras endpoint -@router.get( - "/v1/lora", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def get_active_loras() -> LoraList: - """Returns the currently loaded loras.""" - active_loras = LoraList( - data=[ - LoraCard( - id=pathlib.Path(lora.lora_path).parent.name, - scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, - ) - for lora in model.container.get_loras() - ] +async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: + embeddings_task = asyncio.create_task(get_embeddings(data, request)) + response = await run_with_request_disconnect( + request, + embeddings_task, + f"Embeddings request {request.state.id} cancelled by user.", ) - return active_loras - - -# Load lora endpoint -@router.post( - "/v1/lora/load", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: - """Loads a LoRA into the model container.""" - - if not data.loras: - error_message = handle_request_error( - "List of loras to load is not found.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) - if not lora_dir.exists(): - error_message = handle_request_error( - "A parent lora 
directory does not exist for load. Check your config.yml?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - load_result = await model.load_loras( - lora_dir, **data.model_dump(), skip_wait=data.skip_queue - ) - - return LoraLoadResponse( - success=unwrap(load_result.get("success"), []), - failure=unwrap(load_result.get("failure"), []), - ) - - -# Unload lora endpoint -@router.post( - "/v1/lora/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_loras(): - """Unloads the currently loaded loras.""" - - await model.unload_loras() - - -# Encode tokens endpoint -@router.post( - "/v1/token/encode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: - """Encodes a string or chat completion messages into tokens.""" - - if isinstance(data.text, str): - text = data.text - else: - special_tokens_dict = model.container.get_special_tokens( - unwrap(data.add_bos_token, True) - ) - - template_vars = { - "messages": data.text, - "add_generation_prompt": False, - **special_tokens_dict, - } - - text, _ = model.container.prompt_template.render(template_vars) - - raw_tokens = model.container.encode_tokens(text, **data.get_params()) - tokens = unwrap(raw_tokens, []) - response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) - return response - - -# Decode tokens endpoint -@router.post( - "/v1/token/decode", - dependencies=[Depends(check_api_key), Depends(check_model_container)], -) -async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: - """Decodes tokens into a string.""" - message = model.container.decode_tokens(data.tokens, **data.get_params()) - response = TokenDecodeResponse(text=unwrap(message, "")) - - return response - - -@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) -async def get_key_permission( - x_admin_key: Optional[str] = 
Header(None), - x_api_key: Optional[str] = Header(None), - authorization: Optional[str] = Header(None), -) -> AuthPermissionResponse: - """ - Gets the access level/permission of a provided key in headers. - - Priority: - - X-api-key - - X-admin-key - - Authorization - """ - - test_key = coalesce(x_admin_key, x_api_key, authorization) - - try: - permission = await validate_key_permission(test_key) - return AuthPermissionResponse(permission=permission) - except ValueError as exc: - error_message = handle_request_error(str(exc)).error.message - - raise HTTPException(400, error_message) from exc - - -@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) -@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) -async def get_templates() -> TemplateList: - templates = get_all_templates() - template_strings = [template.stem for template in templates] - return TemplateList(data=template_strings) - - -@router.post( - "/v1/template/switch", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def switch_template(data: TemplateSwitchRequest): - """Switch the currently loaded template""" - if not data.name: - error_message = handle_request_error( - "New template name not found.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - try: - model.container.prompt_template = PromptTemplate.from_file(data.name) - except FileNotFoundError as e: - error_message = handle_request_error( - f"The template name {data.name} doesn't exist. 
Check the spelling?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) from e - - -@router.post( - "/v1/template/unload", - dependencies=[Depends(check_admin_key), Depends(check_model_container)], -) -async def unload_template(): - """Unloads the currently selected template""" - - model.container.prompt_template = None - - -# Sampler override endpoints -@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) -@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) -async def list_sampler_overrides() -> SamplerOverrideListResponse: - """API wrapper to list all currently applied sampler overrides""" - - return SamplerOverrideListResponse( - presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump() - ) - - -@router.post( - "/v1/sampling/override/switch", - dependencies=[Depends(check_admin_key)], -) -async def switch_sampler_override(data: SamplerOverrideSwitchRequest): - """Switch the currently loaded override preset""" - - if data.preset: - try: - sampling.overrides_from_file(data.preset) - except FileNotFoundError as e: - error_message = handle_request_error( - f"Sampler override preset with name {data.preset} does not exist. 
" - + "Check the spelling?", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) from e - elif data.overrides: - sampling.overrides_from_dict(data.overrides) - else: - error_message = handle_request_error( - "A sampler override preset or dictionary wasn't provided.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - -@router.post( - "/v1/sampling/override/unload", - dependencies=[Depends(check_admin_key)], -) -async def unload_sampler_override(): - """Unloads the currently selected override preset""" - - sampling.overrides_from_dict({}) diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index b66277b..ce19733 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -65,3 +65,4 @@ class ChatCompletionStreamChunk(BaseModel): created: int = Field(default_factory=lambda: int(time())) model: str object: str = "chat.completion.chunk" + usage: Optional[UsageStats] = None diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py index d44e41a..6970adf 100644 --- a/endpoints/OAI/types/common.py +++ b/endpoints/OAI/types/common.py @@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel): type: str = "text" +class ChatCompletionStreamOptions(BaseModel): + include_usage: Optional[bool] = False + + class CommonCompletionRequest(BaseSamplerRequest): """Represents a common completion request.""" @@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest): # Generation info (remainder is in BaseSamplerRequest superclass) stream: Optional[bool] = False + stream_options: Optional[ChatCompletionStreamOptions] = None logprobs: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("logprobs", 0) ) diff --git a/endpoints/OAI/types/embedding.py b/endpoints/OAI/types/embedding.py new file mode 100644 index 0000000..7d5779f --- /dev/null +++ b/endpoints/OAI/types/embedding.py @@ -0,0 +1,42 @@ +from 
typing import List, Optional + +from pydantic import BaseModel, Field + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class EmbeddingsRequest(BaseModel): + input: List[str] = Field( + ..., description="List of input texts to generate embeddings for." + ) + encoding_format: str = Field( + "float", + description="Encoding format for the embeddings. " + "Can be 'float' or 'base64'.", + ) + model: Optional[str] = Field( + None, + description="Name of the embedding model to use. " + "If not provided, the default model will be used.", + ) + + +class EmbeddingObject(BaseModel): + object: str = Field("embedding", description="Type of the object.") + embedding: List[float] = Field( + ..., description="Embedding values as a list of floats." + ) + index: int = Field( + ..., description="Index of the input text corresponding to " "the embedding." + ) + + +class EmbeddingsResponse(BaseModel): + object: str = Field("list", description="Type of the response object.") + data: List[EmbeddingObject] = Field(..., description="List of embedding objects.") + model: str = Field(..., description="Name of the embedding model used.") + usage: UsageInfo = Field(..., description="Information about token usage.") diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 9e82b1b..80b5715 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -5,7 +5,6 @@ import pathlib from asyncio import CancelledError from copy import deepcopy from typing import List, Optional -from uuid import uuid4 from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import ( ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.utils.completion import _stream_collector -def _create_response(generations: List[dict], model_name: 
Optional[str]): +def _create_response( + request_id: str, generations: List[dict], model_name: Optional[str] +): """Create a chat completion response from the provided text.""" prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0) @@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]): choices.append(choice) response = ChatCompletionResponse( + id=f"chatcmpl-{request_id}", choices=choices, model=unwrap(model_name, ""), usage=UsageStats( @@ -90,25 +93,40 @@ def _create_response(generations: List[dict], model_name: Optional[str]): def _create_stream_chunk( - const_id: str, + request_id: str, generation: Optional[dict] = None, model_name: Optional[str] = None, + is_usage_chunk: bool = False, ): """Create a chat completion stream chunk from the provided text.""" index = generation.get("index") - logprob_response = None + choices = [] + usage_stats = None - if "finish_reason" in generation: + if is_usage_chunk: + prompt_tokens = unwrap(generation.get("prompt_tokens"), 0) + completion_tokens = unwrap(generation.get("generated_tokens"), 0) + + usage_stats = UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + elif "finish_reason" in generation: choice = ChatCompletionStreamChoice( index=index, finish_reason=generation.get("finish_reason"), ) + + choices.append(choice) else: message = ChatCompletionMessage( role="assistant", content=unwrap(generation.get("text"), "") ) + logprob_response = None + token_probs = unwrap(generation.get("token_probs"), {}) if token_probs: logprobs = unwrap(generation.get("logprobs"), {}) @@ -132,8 +150,13 @@ def _create_stream_chunk( logprobs=logprob_response, ) + choices.append(choice) + chunk = ChatCompletionStreamChunk( - id=const_id, choices=[choice], model=unwrap(model_name, "") + id=f"chatcmpl-{request_id}", + choices=choices, + model=unwrap(model_name, ""), + usage=usage_stats, ) return chunk @@ -215,39 
+238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest): raise HTTPException(400, error_message) from exc -async def _stream_collector( - task_idx: int, - gen_queue: asyncio.Queue, - prompt: str, - abort_event: asyncio.Event, - **kwargs, -): - """Collects a stream and places results in a common queue""" - - try: - new_generation = model.container.generate_gen(prompt, abort_event, **kwargs) - async for generation in new_generation: - generation["index"] = task_idx - - await gen_queue.put(generation) - - if "finish_reason" in generation: - break - except Exception as e: - await gen_queue.put(e) - - async def stream_generate_chat_completion( prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path ): """Generator for the generation process.""" - const_id = f"chatcmpl-{uuid4().hex}" abort_event = asyncio.Event() gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: + logger.info(f"Received chat completion streaming request {request.state.id}") + gen_params = data.to_gen_params() for n in range(0, data.n): @@ -257,7 +259,14 @@ async def stream_generate_chat_completion( task_gen_params = gen_params gen_task = asyncio.create_task( - _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params) + _stream_collector( + n, + gen_queue, + prompt, + request.state.id, + abort_event, + **task_gen_params, + ) ) gen_tasks.append(gen_task) @@ -266,7 +275,9 @@ async def stream_generate_chat_completion( while True: if disconnect_task.done(): abort_event.set() - handle_request_disconnect("Completion generation cancelled by user.") + handle_request_disconnect( + f"Chat completion generation {request.state.id} cancelled by user." 
+ ) generation = await gen_queue.get() @@ -274,17 +285,35 @@ async def stream_generate_chat_completion( if isinstance(generation, Exception): raise generation - response = _create_stream_chunk(const_id, generation, model_path.name) + response = _create_stream_chunk( + request.state.id, generation, model_path.name + ) yield response.model_dump_json() # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): + # Send a usage chunk + if data.stream_options and data.stream_options.include_usage: + usage_chunk = _create_stream_chunk( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield usage_chunk.model_dump_json() + + logger.info( + f"Finished chat completion streaming request {request.state.id}" + ) + + yield "[DONE]" break except CancelledError: # Get out if the request gets disconnected - abort_event.set() - handle_request_disconnect("Chat completion generation cancelled by user.") + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Chat completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console."
@@ -292,7 +321,7 @@ async def stream_generate_chat_completion( async def generate_chat_completion( - prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path + prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path ): gen_tasks: List[asyncio.Task] = [] gen_params = data.to_gen_params() @@ -307,16 +336,23 @@ async def generate_chat_completion( task_gen_params = gen_params gen_tasks.append( - asyncio.create_task(model.container.generate(prompt, **task_gen_params)) + asyncio.create_task( + model.container.generate( + prompt, request.state.id, **task_gen_params + ) + ) ) generations = await asyncio.gather(*gen_tasks) - response = _create_response(generations, model_path.name) + response = _create_response(request.state.id, generations, model_path.name) + + logger.info(f"Finished chat completion request {request.state.id}") return response except Exception as exc: error_message = handle_request_error( - "Chat completion aborted. Maybe the model was unloaded? " + f"Chat completion {request.state.id} aborted. " + "Maybe the model was unloaded? " "Please check the server console." 
).error.message diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 2b5dfbf..52c2bb4 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -7,6 +7,8 @@ from copy import deepcopy from fastapi import HTTPException, Request from typing import List, Union +from loguru import logger + from common import model from common.networking import ( get_generator_error, @@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import ( from endpoints.OAI.types.common import UsageStats -def _create_response(generations: Union[dict, List[dict]], model_name: str = ""): +def _create_response( + request_id: str, generations: Union[dict, List[dict]], model_name: str = "" +): """Create a completion response from the provided choices.""" # Convert the single choice object into a list @@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "") completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0) response = CompletionResponse( + id=f"cmpl-{request_id}", choices=choices, model=model_name, usage=UsageStats( @@ -77,13 +82,16 @@ async def _stream_collector( task_idx: int, gen_queue: asyncio.Queue, prompt: str, + request_id: str, abort_event: asyncio.Event, **kwargs, ): """Collects a stream and places results in a common queue""" try: - new_generation = model.container.generate_gen(prompt, abort_event, **kwargs) + new_generation = model.container.generate_gen( + prompt, request_id, abort_event, **kwargs + ) async for generation in new_generation: generation["index"] = task_idx @@ -106,6 +114,8 @@ async def stream_generate_completion( disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: + logger.info(f"Received streaming completion request {request.state.id}") + gen_params = data.to_gen_params() for n in range(0, data.n): @@ -116,7 +126,12 @@ async def stream_generate_completion( gen_task = asyncio.create_task( _stream_collector( - n, gen_queue, 
data.prompt, abort_event, **task_gen_params + n, + gen_queue, + data.prompt, + request.state.id, + abort_event, + **task_gen_params, ) ) @@ -126,7 +141,9 @@ async def stream_generate_completion( while True: if disconnect_task.done(): abort_event.set() - handle_request_disconnect("Completion generation cancelled by user.") + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) generation = await gen_queue.get() @@ -134,31 +151,39 @@ async def stream_generate_completion( if isinstance(generation, Exception): raise generation - response = _create_response(generation, model_path.name) + response = _create_response(request.state.id, generation, model_path.name) yield response.model_dump_json() # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): yield "[DONE]" + logger.info(f"Finished streaming completion request {request.state.id}") break except CancelledError: # Get out if the request gets disconnected - abort_event.set() - handle_request_disconnect("Completion generation cancelled by user.") + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( - "Completion aborted. Please check the server console." + f"Completion {request.state.id} aborted. Please check the server console." 
) -async def generate_completion(data: CompletionRequest, model_path: pathlib.Path): +async def generate_completion( + data: CompletionRequest, request: Request, model_path: pathlib.Path +): """Non-streaming generate for completions""" gen_tasks: List[asyncio.Task] = [] gen_params = data.to_gen_params() try: + logger.info(f"Received completion request {request.state.id}") + for n in range(0, data.n): # Deepcopy gen params above the first index # to ensure nested structures aren't shared @@ -169,17 +194,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path) gen_tasks.append( asyncio.create_task( - model.container.generate(data.prompt, **task_gen_params) + model.container.generate( + data.prompt, request.state.id, **task_gen_params + ) ) ) generations = await asyncio.gather(*gen_tasks) - response = _create_response(generations, model_path.name) + response = _create_response(request.state.id, generations, model_path.name) + + logger.info(f"Finished completion request {request.state.id}") return response except Exception as exc: error_message = handle_request_error( - "Completion aborted. Maybe the model was unloaded? " + f"Completion {request.state.id} aborted. Maybe the model was unloaded? " "Please check the server console." ).error.message diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py new file mode 100644 index 0000000..5b43953 --- /dev/null +++ b/endpoints/OAI/utils/embeddings.py @@ -0,0 +1,64 @@ +""" +This file is derived from +[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py) +and modified. +The changes introduced are: Suppression of progress bar, +typing/pydantic classes moved into this file, +embeddings function declared async.
+""" + +import base64 +from fastapi import Request +import numpy as np +from loguru import logger + +from common import model +from endpoints.OAI.types.embedding import ( + EmbeddingObject, + EmbeddingsRequest, + EmbeddingsResponse, + UsageInfo, +) + + +def float_list_to_base64(float_array: np.ndarray) -> str: + """ + Encodes the provided float32 array as a base64 string for OpenAI clients + Ex. float_array = np.array(float_list, dtype="float32") + """ + + # Encode raw bytes into base64 + encoded_bytes = base64.b64encode(float_array.tobytes()) + + # Turn raw base64 encoded bytes into ASCII + ascii_string = encoded_bytes.decode("ascii") + return ascii_string + + +async def get_embeddings(data: EmbeddingsRequest, request: Request) -> EmbeddingsResponse: + model_path = model.embeddings_container.model_dir + + logger.info(f"Received embeddings request {request.state.id}") + embedding_data = await model.embeddings_container.generate(data.input) + + # OAI expects a return of base64 if the input is base64 + embedding_object = [ + EmbeddingObject( + embedding=float_list_to_base64(emb) + if data.encoding_format == "base64" + else emb.tolist(), + index=n, + ) + for n, emb in enumerate(embedding_data.get("embeddings")) + ] + + usage = embedding_data.get("usage") + response = EmbeddingsResponse( + data=embedding_object, + model=model_path.name, + usage=UsageInfo(prompt_tokens=usage, total_tokens=usage), + ) + + logger.info(f"Finished embeddings request {request.state.id}") + + return response diff --git a/endpoints/OAI/utils/lora.py b/endpoints/OAI/utils/lora.py deleted file mode 100644 index d00910f..0000000 --- a/endpoints/OAI/utils/lora.py +++ /dev/null @@ -1,14 +0,0 @@ -import pathlib - -from endpoints.OAI.types.lora import LoraCard, LoraList - - -def get_lora_list(lora_path: pathlib.Path): - """Get the list of Lora cards from the provided path.""" - lora_list = LoraList() - for path in lora_path.iterdir(): - if path.is_dir(): - lora_card = LoraCard(id=path.name) - lora_list.data.append(lora_card) - -
return lora_list diff --git a/endpoints/core/router.py b/endpoints/core/router.py new file mode 100644 index 0000000..a857a1a --- /dev/null +++ b/endpoints/core/router.py @@ -0,0 +1,521 @@ +import asyncio +import pathlib +from sys import maxsize +from fastapi import APIRouter, Depends, HTTPException, Request +from sse_starlette import EventSourceResponse + +from common import config, model, sampling +from common.auth import check_admin_key, check_api_key, get_key_permission +from common.downloader import hf_repo_download +from common.model import check_embeddings_container, check_model_container +from common.networking import handle_request_error, run_with_request_disconnect +from common.templating import PromptTemplate, get_all_templates +from common.utils import unwrap +from endpoints.core.types.auth import AuthPermissionResponse +from endpoints.core.types.download import DownloadRequest, DownloadResponse +from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse +from endpoints.core.types.model import ( + EmbeddingModelLoadRequest, + ModelCard, + ModelList, + ModelLoadRequest, + ModelLoadResponse, +) +from endpoints.core.types.sampler_overrides import ( + SamplerOverrideListResponse, + SamplerOverrideSwitchRequest, +) +from endpoints.core.types.template import TemplateList, TemplateSwitchRequest +from endpoints.core.types.token import ( + TokenDecodeRequest, + TokenDecodeResponse, + TokenEncodeRequest, + TokenEncodeResponse, +) +from endpoints.core.utils.lora import get_active_loras, get_lora_list +from endpoints.core.utils.model import ( + get_current_model, + get_current_model_list, + get_model_list, + stream_model_load, +) + + +router = APIRouter() + + +# Model list endpoint +@router.get("/v1/models", dependencies=[Depends(check_api_key)]) +@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) +async def list_models(request: Request) -> ModelList: + """ + Lists all models in the model directory. 
+ + Requires an admin key to see all models. + """ + + model_config = config.model_config() + model_dir = unwrap(model_config.get("model_dir"), "models") + model_path = pathlib.Path(model_dir) + + draft_model_dir = config.draft_model_config().get("draft_model_dir") + + if get_key_permission(request) == "admin": + models = get_model_list(model_path.resolve(), draft_model_dir) + else: + models = await get_current_model_list() + + if unwrap(model_config.get("use_dummy_models"), False): + models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) + + return models + + +# Currently loaded model endpoint +@router.get( + "/v1/model", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def current_model() -> ModelCard: + """Returns the currently loaded model.""" + + return get_current_model() + + +@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) +async def list_draft_models(request: Request) -> ModelList: + """ + Lists all draft models in the model directory. + + Requires an admin key to see all draft models. + """ + + if get_key_permission(request) == "admin": + draft_model_dir = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + draft_model_path = pathlib.Path(draft_model_dir) + + models = get_model_list(draft_model_path.resolve()) + else: + models = await get_current_model_list(is_draft=True) + + return models + + +# Load model endpoint +@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) +async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: + """Loads a model into the model container. 
This returns an SSE stream.""" + + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models")) + model_path = model_path / data.name + + draft_model_path = None + if data.draft: + if not data.draft.draft_model_name: + error_message = handle_request_error( + "Could not find the draft model name for model load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + draft_model_path = unwrap( + config.draft_model_config().get("draft_model_dir"), "models" + ) + + if not model_path.exists(): + error_message = handle_request_error( + "Could not find the model path for load. Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + return EventSourceResponse( + stream_model_load(data, model_path, draft_model_path), ping=maxsize + ) + + +# Unload model endpoint +@router.post( + "/v1/model/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_model(): + """Unloads the currently loaded model.""" + await model.unload_model(skip_wait=True) + + +@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) +async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: + """Downloads a model from HuggingFace.""" + + try: + download_task = asyncio.create_task(hf_repo_download(**data.model_dump())) + + # For now, the downloader and request data are 1:1 + download_path = await run_with_request_disconnect( + request, + download_task, + "Download request cancelled by user. 
Files have been cleaned up.", + ) + + return DownloadResponse(download_path=str(download_path)) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + +# Lora list endpoint +@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) +@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) +async def list_all_loras(request: Request) -> LoraList: + """ + Lists all LoRAs in the lora directory. + + Requires an admin key to see all LoRAs. + """ + + if get_key_permission(request) == "admin": + lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + loras = get_lora_list(lora_path.resolve()) + else: + loras = get_active_loras() + + return loras + + +# Currently loaded loras endpoint +@router.get( + "/v1/lora", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def active_loras() -> LoraList: + """Returns the currently loaded loras.""" + + return get_active_loras() + + +# Load lora endpoint +@router.post( + "/v1/lora/load", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: + """Loads a LoRA into the model container.""" + + if not data.loras: + error_message = handle_request_error( + "List of loras to load is not found.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) + if not lora_dir.exists(): + error_message = handle_request_error( + "A parent lora directory does not exist for load. 
Check your config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + load_result = await model.load_loras( + lora_dir, **data.model_dump(), skip_wait=data.skip_queue + ) + + return LoraLoadResponse( + success=unwrap(load_result.get("success"), []), + failure=unwrap(load_result.get("failure"), []), + ) + + +# Unload lora endpoint +@router.post( + "/v1/lora/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_loras(): + """Unloads the currently loaded loras.""" + + await model.unload_loras() + + +@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +async def list_embedding_models(request: Request) -> ModelList: + """ + Lists all embedding models in the model directory. + + Requires an admin key to see all embedding models. + """ + + if get_key_permission(request) == "admin": + embedding_model_dir = unwrap( + config.embeddings_config().get("embedding_model_dir"), "models" + ) + embedding_model_path = pathlib.Path(embedding_model_dir) + + models = get_model_list(embedding_model_path.resolve()) + else: + models = await get_current_model_list(model_type="embedding") + + return models + + +@router.get( + "/v1/model/embedding", + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], +) +async def get_embedding_model() -> ModelCard: + """Returns the currently loaded embedding model.""" + models = await get_current_model_list(model_type="embedding") + + return models.data[0] + + +@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +async def load_embedding_model( + request: Request, data: EmbeddingModelLoadRequest +) -> ModelLoadResponse: + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + embedding_model_dir = pathlib.Path( + 
unwrap(config.embeddings_config().get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_dir / data.name + + if not embedding_model_path.exists(): + error_message = handle_request_error( + "Could not find the embedding model path for load. " + + "Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + load_task = asyncio.create_task( + model.load_embedding_model(embedding_model_path, **data.model_dump()) + ) + await run_with_request_disconnect( + request, load_task, "Embedding model load request cancelled by user." + ) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + response = ModelLoadResponse( + model_type="embedding_model", module=1, modules=1, status="finished" + ) + + return response + + +@router.post( + "/v1/model/embedding/unload", + dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], +) +async def unload_embedding_model(): + """Unloads the current embedding model.""" + + await model.unload_embedding_model() + + +# Encode tokens endpoint +@router.post( + "/v1/token/encode", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: + """Encodes a string or chat completion messages into tokens.""" + + if isinstance(data.text, str): + text = data.text + else: + special_tokens_dict = model.container.get_special_tokens( + unwrap(data.add_bos_token, True) + ) + + template_vars = { + "messages": data.text, + "add_generation_prompt": False, + **special_tokens_dict, + } + + text, _ = model.container.prompt_template.render(template_vars) + + raw_tokens = model.container.encode_tokens(text, **data.get_params()) + tokens = unwrap(raw_tokens, []) + response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) + + return response + + +# Decode tokens endpoint
+@router.post( + "/v1/token/decode", + dependencies=[Depends(check_api_key), Depends(check_model_container)], +) +async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: + """Decodes tokens into a string.""" + + message = model.container.decode_tokens(data.tokens, **data.get_params()) + response = TokenDecodeResponse(text=unwrap(message, "")) + + return response + + +@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) +async def key_permission(request: Request) -> AuthPermissionResponse: + """ + Gets the access level/permission of a provided key in headers. + + Priority: + - X-admin-key + - X-api-key + - Authorization + """ + + try: + permission = get_key_permission(request) + return AuthPermissionResponse(permission=permission) + except ValueError as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + +@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) +@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +async def list_templates(request: Request) -> TemplateList: + """ + Get a list of all templates. + + Requires an admin key to see all templates. 
+ """ + + template_strings = [] + if get_key_permission(request) == "admin": + templates = get_all_templates() + template_strings = [template.stem for template in templates] + else: + if model.container and model.container.prompt_template: + template_strings.append(model.container.prompt_template.name) + + return TemplateList(data=template_strings) + + +@router.post( + "/v1/template/switch", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def switch_template(data: TemplateSwitchRequest): + """Switch the currently loaded template.""" + + if not data.name: + error_message = handle_request_error( + "New template name not found.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + model.container.prompt_template = PromptTemplate.from_file(data.name) + except FileNotFoundError as e: + error_message = handle_request_error( + f"The template name {data.name} doesn't exist. Check the spelling?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) from e + + +@router.post( + "/v1/template/unload", + dependencies=[Depends(check_admin_key), Depends(check_model_container)], +) +async def unload_template(): + """Unloads the currently selected template""" + + model.container.prompt_template = None + + +# Sampler override endpoints +@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) +@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: + """ + List all currently applied sampler overrides. + + Requires an admin key to see all override presets. 
+ """ + + if get_key_permission(request) == "admin": + presets = sampling.get_all_presets() + else: + presets = [] + + return SamplerOverrideListResponse( + presets=presets, **sampling.overrides_container.model_dump() + ) + + +@router.post( + "/v1/sampling/override/switch", + dependencies=[Depends(check_admin_key)], +) +async def switch_sampler_override(data: SamplerOverrideSwitchRequest): + """Switch the currently loaded override preset""" + + if data.preset: + try: + sampling.overrides_from_file(data.preset) + except FileNotFoundError as e: + error_message = handle_request_error( + f"Sampler override preset with name {data.preset} does not exist. " + + "Check the spelling?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) from e + elif data.overrides: + sampling.overrides_from_dict(data.overrides) + else: + error_message = handle_request_error( + "A sampler override preset or dictionary wasn't provided.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + +@router.post( + "/v1/sampling/override/unload", + dependencies=[Depends(check_admin_key)], +) +async def unload_sampler_override(): + """Unloads the currently selected override preset""" + + sampling.overrides_from_dict({}) diff --git a/endpoints/OAI/types/auth.py b/endpoints/core/types/auth.py similarity index 100% rename from endpoints/OAI/types/auth.py rename to endpoints/core/types/auth.py diff --git a/endpoints/OAI/types/download.py b/endpoints/core/types/download.py similarity index 94% rename from endpoints/OAI/types/download.py rename to endpoints/core/types/download.py index ac681bf..cf49501 100644 --- a/endpoints/OAI/types/download.py +++ b/endpoints/core/types/download.py @@ -17,6 +17,7 @@ class DownloadRequest(BaseModel): include: List[str] = Field(default_factory=_generate_include_list) exclude: List[str] = Field(default_factory=list) chunk_limit: Optional[int] = None + timeout: Optional[int] = None class DownloadResponse(BaseModel): 
diff --git a/endpoints/OAI/types/lora.py b/endpoints/core/types/lora.py similarity index 100% rename from endpoints/OAI/types/lora.py rename to endpoints/core/types/lora.py diff --git a/endpoints/OAI/types/model.py b/endpoints/core/types/model.py similarity index 92% rename from endpoints/OAI/types/model.py rename to endpoints/core/types/model.py index 30730b8..1e2eb46 100644 --- a/endpoints/OAI/types/model.py +++ b/endpoints/core/types/model.py @@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel): # Config arguments draft_rope_scale: Optional[float] = Field( default_factory=lambda: get_config_default( - "draft_rope_scale", 1.0, is_draft=True + "draft_rope_scale", 1.0, model_type="draft" ) ) draft_rope_alpha: Optional[float] = Field( description="Automatically calculated if not present", default_factory=lambda: get_config_default( - "draft_rope_alpha", None, is_draft=True + "draft_rope_alpha", None, model_type="draft" ), examples=[1.0], ) draft_cache_mode: Optional[str] = Field( default_factory=lambda: get_config_default( - "draft_cache_mode", "FP16", is_draft=True + "draft_cache_mode", "FP16", model_type="draft" ) ) @@ -137,6 +137,15 @@ class ModelLoadRequest(BaseModel): skip_queue: Optional[bool] = False +class EmbeddingModelLoadRequest(BaseModel): + name: str + embeddings_device: Optional[str] = Field( + default_factory=lambda: get_config_default( + "embeddings_device", model_type="embedding" + ) + ) + + class ModelLoadResponse(BaseModel): """Represents a model load response.""" diff --git a/endpoints/OAI/types/sampler_overrides.py b/endpoints/core/types/sampler_overrides.py similarity index 100% rename from endpoints/OAI/types/sampler_overrides.py rename to endpoints/core/types/sampler_overrides.py diff --git a/endpoints/OAI/types/template.py b/endpoints/core/types/template.py similarity index 100% rename from endpoints/OAI/types/template.py rename to endpoints/core/types/template.py diff --git a/endpoints/OAI/types/token.py 
b/endpoints/core/types/token.py similarity index 100% rename from endpoints/OAI/types/token.py rename to endpoints/core/types/token.py diff --git a/endpoints/core/utils/lora.py b/endpoints/core/utils/lora.py new file mode 100644 index 0000000..c8c9cc4 --- /dev/null +++ b/endpoints/core/utils/lora.py @@ -0,0 +1,30 @@ +import pathlib + +from common import model +from endpoints.core.types.lora import LoraCard, LoraList + + +def get_lora_list(lora_path: pathlib.Path): + """Get the list of Lora cards from the provided path.""" + lora_list = LoraList() + for path in lora_path.iterdir(): + if path.is_dir(): + lora_card = LoraCard(id=path.name) + lora_list.data.append(lora_card) + + return lora_list + + +def get_active_loras(): + if model.container: + active_loras = [ + LoraCard( + id=pathlib.Path(lora.lora_path).parent.name, + scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, + ) + for lora in model.container.get_loras() + ] + else: + active_loras = [] + + return LoraList(data=active_loras) diff --git a/endpoints/OAI/utils/model.py b/endpoints/core/utils/model.py similarity index 56% rename from endpoints/OAI/utils/model.py rename to endpoints/core/utils/model.py index 0502193..fc61337 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/core/utils/model.py @@ -2,11 +2,12 @@ import pathlib from asyncio import CancelledError from typing import Optional -from common import model +from common import gen_logging, model from common.networking import get_generator_error, handle_request_disconnect - -from endpoints.OAI.types.model import ( +from common.utils import unwrap +from endpoints.core.types.model import ( ModelCard, + ModelCardParameters, ModelList, ModelLoadRequest, ModelLoadResponse, @@ -31,6 +32,61 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list +async def get_current_model_list(model_type: str = "model"): + """ + Gets the current model in list format and with path only. 
+ + Unified for fetching both models and embedding models. + """ + + current_models = [] + model_path = None + + # Make sure the model container exists + if model_type == "model" or model_type == "draft": + if model.container: + model_path = model.container.get_model_path(model_type == "draft") + elif model_type == "embedding": + if model.embeddings_container: + model_path = model.embeddings_container.model_dir + + if model_path: + current_models.append(ModelCard(id=model_path.name)) + + return ModelList(data=current_models) + + +def get_current_model(): + """Gets the current model with all parameters.""" + + model_params = model.container.get_model_parameters() + draft_model_params = model_params.pop("draft", {}) + + if draft_model_params: + model_params["draft"] = ModelCard( + id=unwrap(draft_model_params.get("name"), "unknown"), + parameters=ModelCardParameters.model_validate(draft_model_params), + ) + else: + draft_model_params = None + + model_card = ModelCard( + id=unwrap(model_params.pop("name", None), "unknown"), + parameters=ModelCardParameters.model_validate(model_params), + logging=gen_logging.PREFERENCES, + ) + + if draft_model_params: + draft_card = ModelCard( + id=unwrap(draft_model_params.pop("name", None), "unknown"), + parameters=ModelCardParameters.model_validate(draft_model_params), + ) + + model_card.parameters.draft = draft_card + + return model_card + + async def stream_model_load( data: ModelLoadRequest, model_path: pathlib.Path, diff --git a/endpoints/server.py b/endpoints/server.py index 7ceb208..0b3edfb 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,40 +1,74 @@ +import asyncio +from typing import Optional import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger +from common import config from common.logger import UVICORN_LOG_CONFIG -from endpoints.OAI.router import router as OAIRouter - -app = FastAPI( - title="TabbyAPI", - summary="An OAI compatible exllamav2 
API that's both lightweight and fast", - description=( - "This docs page is not meant to send requests! Please use a service " - "like Postman or a frontend UI." - ), -) - -# ALlow CORS requests -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +from common.networking import get_global_depends +from common.utils import unwrap +from endpoints.Kobold import router as KoboldRouter +from endpoints.OAI import router as OAIRouter +from endpoints.core.router import router as CoreRouter -def setup_app(): +def setup_app(host: Optional[str] = None, port: Optional[int] = None): """Includes the correct routers for startup""" - app.include_router(OAIRouter) + app = FastAPI( + title="TabbyAPI", + summary="An OAI compatible exllamav2 API that's both lightweight and fast", + description=( + "This docs page is not meant to send requests! Please use a service " + "like Postman or a frontend UI." + ), + dependencies=get_global_depends(), + ) + + # Allow CORS requests + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + api_servers = unwrap(config.network_config().get("api_servers"), []) + + # Map for API id to server router + router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter} + + # Include the OAI API by default + if api_servers: + for server in api_servers: + selected_server = router_mapping.get(server.lower()) + + if selected_server: + app.include_router(selected_server.setup()) + + logger.info(f"Starting {selected_server.api_name} API") + for path, url in selected_server.urls.items(): + formatted_url = url.format(host=host, port=port) + logger.info(f"{path}: {formatted_url}") + else: + app.include_router(OAIRouter.setup()) + for path, url in OAIRouter.urls.items(): + formatted_url = url.format(host=host, port=port) + logger.info(f"{path}: {formatted_url}") + + # Include core API request paths +
app.include_router(CoreRouter) + + return app def export_openapi(): """Function to return the OpenAPI JSON from the API server""" - setup_app() + app = setup_app() return app.openapi() @@ -43,17 +77,21 @@ async def start_api(host: str, port: int): # TODO: Move OAI API to a separate folder logger.info(f"Developer documentation: http://{host}:{port}/redoc") - logger.info(f"Completions: http://{host}:{port}/v1/completions") - logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") + # logger.info(f"Completions: http://{host}:{port}/v1/completions") + # logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") # Setup app - setup_app() + app = setup_app(host, port) + + # Get the current event loop + loop = asyncio.get_running_loop() config = uvicorn.Config( app, host=host, port=port, log_config=UVICORN_LOG_CONFIG, + loop=loop, ) server = uvicorn.Server(config) diff --git a/main.py b/main.py index e089d81..b0c5108 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,10 @@ """The main tabbyAPI module. 
Contains the FastAPI server and endpoints.""" import asyncio -import aiofiles import json import os import pathlib +import platform import signal from loguru import logger from typing import Optional @@ -23,51 +23,8 @@ if not do_export_openapi: from backends.exllamav2.utils import check_exllama_version -async def entrypoint(args: Optional[dict] = None): - """Entry function for program startup""" - - setup_logger() - - # Set up signal aborting - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - if os.getenv("EXPORT_OPENAPI", "").lower() in ("true", "1"): - openapi_json = export_openapi() - - async with aiofiles.open("openapi.json", "w") as f: - await f.write(json.dumps(openapi_json)) - logger.info("Successfully wrote OpenAPI spec to openapi.json") - - return - - # Load from YAML config - config.from_file(pathlib.Path("config.yml")) - - # Parse and override config from args - if args is None: - parser = init_argparser() - args = convert_args_to_dict(parser.parse_args(), parser) - - config.from_args(args) - - developer_config = config.developer_config() - - # Check exllamav2 version and give a descriptive error if it's too old - # Skip if launching unsafely - - if unwrap(developer_config.get("unsafe_launch"), False): - logger.warning( - "UNSAFE: Skipping ExllamaV2 version check.\n" - "If you aren't a developer, please keep this off!" 
- ) - else: - check_exllama_version() - - # Enable CUDA malloc backend - if unwrap(developer_config.get("cuda_malloc_backend"), False): - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" - logger.warning("Enabled the experimental CUDA malloc backend.") +async def entrypoint_async(): + """Async entry function for program startup""" network_config = config.network_config() @@ -97,7 +54,7 @@ async def entrypoint(args: Optional[dict] = None): load_auth_keys(unwrap(network_config.get("disable_auth"), False)) # Override the generation log options if given - log_config = config.gen_logging_config() + log_config = config.logging_config() if log_config: gen_logging.update_from_dict(log_config) @@ -128,8 +85,98 @@ async def entrypoint(args: Optional[dict] = None): lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) await model.container.load_loras(lora_dir.resolve(), **lora_config) + # If an initial embedding model name is specified, create a separate container + # and load the model + embedding_config = config.embeddings_config() + embedding_model_name = embedding_config.get("embedding_model_name") + if embedding_model_name: + embedding_model_path = pathlib.Path( + unwrap(embedding_config.get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_path / embedding_model_name + + try: + await model.load_embedding_model(embedding_model_path, **embedding_config) + except ImportError as ex: + logger.error(ex.msg) + await start_api(host, port) +def entrypoint(arguments: Optional[dict] = None): + setup_logger() + + # Set up signal aborting + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Load from YAML config + config.from_file(pathlib.Path("config.yml")) + + # Parse and override config from args + if arguments is None: + parser = init_argparser() + arguments = convert_args_to_dict(parser.parse_args(), parser) + + config.from_args(arguments) + + if do_export_openapi: + 
openapi_json = export_openapi() + + with open("openapi.json", "w") as f: + f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to openapi.json") + + return + + developer_config = config.developer_config() + + # Check exllamav2 version and give a descriptive error if it's too old + # Skip if launching unsafely + + if unwrap(developer_config.get("unsafe_launch"), False): + logger.warning( + "UNSAFE: Skipping ExllamaV2 version check.\n" + "If you aren't a developer, please keep this off!" + ) + else: + check_exllama_version() + + # Enable CUDA malloc backend + if unwrap(developer_config.get("cuda_malloc_backend"), False): + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" + logger.warning("EXPERIMENTAL: Enabled the pytorch CUDA malloc backend.") + + # Use Uvloop/Winloop + if unwrap(developer_config.get("uvloop"), False): + if platform.system() == "Windows": + from winloop import install + else: + from uvloop import install + + # Set loop event policy + install() + + logger.warning("EXPERIMENTAL: Running program with Uvloop/Winloop.") + + # Set the process priority + if unwrap(developer_config.get("realtime_process_priority"), False): + import psutil + + current_process = psutil.Process(os.getpid()) + if platform.system() == "Windows": + current_process.nice(psutil.REALTIME_PRIORITY_CLASS) + else: + current_process.nice(psutil.IOPRIO_CLASS_RT) + + logger.warning( + "EXPERIMENTAL: Process priority set to Realtime. \n" + "If you're not running on administrator/sudo, the priority is set to high." 
+        )
+
+    # Enter into the async event loop
+    asyncio.run(entrypoint_async())
+
+
 if __name__ == "__main__":
-    asyncio.run(entrypoint())
+    entrypoint()
diff --git a/pyproject.toml b/pyproject.toml
index 9d9aaf8..e591a15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ version = "0.0.1"
 description = "An OAI compatible exllamav2 API that's both lightweight and fast"
 requires-python = ">=3.10"
 dependencies = [
-    "fastapi >= 0.110.0",
+    "fastapi-slim >= 0.110.0",
     "pydantic >= 2.0.0",
     "PyYAML",
     "rich",
@@ -28,13 +28,21 @@ dependencies = [
     "tokenizers",
     "lm-format-enforcer >= 0.9.6",
     "aiofiles",
+    "aiohttp",
+    "huggingface_hub",
+    "psutil",
+    "httptools>=0.5.0",
+
+    # Improved asyncio loops
+    "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
+    "winloop ; platform_system == 'Windows'",

     # TEMP: Remove once 2.x is fixed in upstream
     "numpy < 2.0.0",

-    # TODO: Maybe move these to a downloader feature?
-    "aiohttp",
-    "huggingface_hub",
+    # For python 3.12
+    "fastparquet @ https://github.com/theroyallab/fastparquet/releases/download/v2024.5.0/fastparquet-0.1.dev837-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "setuptools ; python_version == '3.12'"
 ]

 [project.urls]
@@ -43,7 +51,9 @@ dependencies = [
 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines"
+    "outlines",
+    "infinity-emb",
+    "sentence-transformers",
 ]
 dev = [
     "ruff == 0.3.2"
 ]
@@ -58,22 +68,22 @@ cu121 = [
     "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
-    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",

     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 ]
 cu118 = [
     # Torch
@@ -85,17 +95,17 @@ cu118 = [
     "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 ]
 amd = [
     # Torch triton for ROCm
@@ -109,9 +119,9 @@ amd = [
     "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",

     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.6/exllamav2-0.1.6+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]

 # MARK: Ruff options
diff --git a/start.bat b/start.bat
index c7b8330..4ae3247 100644
--- a/start.bat
+++ b/start.bat
@@ -19,3 +19,5 @@ if exist "%CONDA_PREFIX%" (

 :: Call the python script with batch args
 call python start.py %*
+
+pause
diff --git a/start.py b/start.py
index d1f8843..490570e 100644
--- a/start.py
+++ b/start.py
@@ -1,17 +1,21 @@
 """Utility to automatically upgrade and start the API"""

-import asyncio
 import argparse
+import json
 import os
 import pathlib
 import platform
 import subprocess
 import sys
 from shutil import copyfile
+import traceback

 from common.args import convert_args_to_dict, init_argparser

+start_options = {}
+
+
 def get_user_choice(question: str, options_dict: dict):
     """
     Gets user input in a commandline script.
@@ -40,36 +44,24 @@ def get_install_features(lib_name: str = None):
     """Fetches the appropriate requirements file depending on the GPU"""
     install_features = None
    possible_features = ["cu121", "cu118", "amd"]
-    saved_lib_path = pathlib.Path("gpu_lib.txt")
-    if lib_name:
-        print("Overriding GPU lib name from args.")
-    else:
-        # Try getting the GPU lib from file
-        if saved_lib_path.exists():
-            with open(saved_lib_path.resolve(), "r") as f:
-                lib_name = f.readline().strip()
-        else:
-            # Ask the user for the GPU lib
-            gpu_lib_choices = {
-                "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
-                "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
-                "C": {"pretty": "AMD", "internal": "amd"},
-            }
-            user_input = get_user_choice(
-                "Select your GPU. If you don't know, select Cuda 12.x (A)",
-                gpu_lib_choices,
-            )
+    if not lib_name:
+        # Ask the user for the GPU lib
+        gpu_lib_choices = {
+            "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
+            "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
+            "C": {"pretty": "AMD", "internal": "amd"},
+        }
+        user_input = get_user_choice(
+            "Select your GPU. If you don't know, select Cuda 12.x (A)",
+            gpu_lib_choices,
+        )

-            lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
+        lib_name = gpu_lib_choices.get(user_input, {}).get("internal")

-            # Write to a file for subsequent runs
-            with open(saved_lib_path.resolve(), "w") as f:
-                f.write(lib_name)
-            print(
-                "Saving your choice to gpu_lib.txt. "
-                "Delete this file and restart if you want to change your selection."
-            )
+        # Write to start options
+        start_options["gpu_lib"] = lib_name
+        print("Saving your choice to start options.")

     # Assume default if the file is invalid
     if lib_name and lib_name in possible_features:
@@ -79,7 +71,7 @@ def get_install_features(lib_name: str = None):
     print(
         f"WARN: GPU library {lib_name} not found. "
         "Skipping GPU-specific dependencies.\n"
-        "WARN: Please delete gpu_lib.txt and restart "
+        "WARN: Please remove the `gpu_lib` key from start_options.json and restart "
         "if you want to change your selection."
     )
     return
@@ -105,10 +97,22 @@ def add_start_args(parser: argparse.ArgumentParser):
     """Add start script args to the provided parser"""
     start_group = parser.add_argument_group("start")
     start_group.add_argument(
-        "-iu",
-        "--ignore-upgrade",
+        "-ur",
+        "--update-repository",
         action="store_true",
-        help="Ignore requirements upgrade",
+        help="Update local git repository to latest",
+    )
+    start_group.add_argument(
+        "-ud",
+        "--update-deps",
+        action="store_true",
+        help="Update all pip dependencies",
+    )
+    start_group.add_argument(
+        "-fr",
+        "--force-reinstall",
+        action="store_true",
+        help="Forces a reinstall of dependencies. Only works with --update-deps",
     )
     start_group.add_argument(
         "-nw",
@@ -123,6 +127,26 @@
     )


+def migrate_gpu_lib():
+    gpu_lib_path = pathlib.Path("gpu_lib.txt")
+
+    if not gpu_lib_path.exists():
+        return
+
+    print("Migrating gpu_lib.txt to the new start_options.json")
+    with open("gpu_lib.txt", "r") as gpu_lib_file:
+        start_options["gpu_lib"] = gpu_lib_file.readline().strip()
+        start_options["first_run_done"] = True
+
+    # Remove the old file
+    gpu_lib_path.unlink()
+
+    print(
+        "Successfully migrated gpu lib options to start_options. "
+        "The old file has been deleted."
+    )
+
+
 if __name__ == "__main__":
     subprocess.run(["pip", "-V"])
@@ -130,6 +154,35 @@
     parser = init_argparser()
     add_start_args(parser)
     args = parser.parse_args()
+    script_ext = "bat" if platform.system() == "Windows" else "sh"
+
+    start_options_path = pathlib.Path("start_options.json")
+    if start_options_path.exists():
+        with open(start_options_path) as start_options_file:
+            start_options = json.load(start_options_file)
+            print("Loaded your saved preferences from `start_options.json`")
+
+    if start_options.get("first_run_done"):
+        first_run = False
+    else:
+        print(
+            "It looks like you're running TabbyAPI for the first time. "
+            "Getting things ready..."
+        )
+
+        # Migrate from old setting storage
+        migrate_gpu_lib()
+
+    # Set variables that rely on start options
+    first_run = not start_options.get("first_run_done")
+
+    if args.gpu_lib:
+        print("Overriding GPU lib name from args.")
+        gpu_lib = args.gpu_lib
+    elif "gpu_lib" in start_options:
+        gpu_lib = start_options.get("gpu_lib")
+    else:
+        gpu_lib = None

     # Create a config if it doesn't exist
     # This is not necessary to run TabbyAPI, but is new user proof
@@ -145,18 +198,65 @@
         f"Created one at {str(config_path.resolve())}"
     )

-    if args.ignore_upgrade:
-        print("Ignoring pip dependency upgrade due to user request.")
-    else:
-        install_features = None if args.nowheel else get_install_features(args.gpu_lib)
-        features = f"[{install_features}]" if install_features else ""
+    if args.update_repository:
+        print("Pulling latest changes from Github.")
+        pull_command = "git pull"
+        subprocess.run(pull_command.split(" "))
+
+    if first_run or args.update_deps:
+        install_command = ["pip", "install", "-U"]
+
+        # Force a reinstall of the updated dependency if needed
+        if args.force_reinstall:
+            install_command.append("--force-reinstall")
+
+        install_features = None if args.nowheel else get_install_features(gpu_lib)
+        features = f".[{install_features}]" if install_features else "."
+        install_command.append(features)  # pip install .[features]

-        install_command = f"pip install -U .{features}"
-        print(f"Running install command: {install_command}")
-        subprocess.run(install_command.split(" "))
+        print(f"Running install command: {' '.join(install_command)}")
+        subprocess.run(install_command)
+        print()
+
+        if args.update_deps:
+            print(
+                f"Dependencies updated. Please run TabbyAPI with `start.{script_ext}`. "
+                "Exiting."
+            )
+            sys.exit(0)
+        else:
+            print(
+                f"Dependencies installed. Update them with `update_deps.{script_ext}` "
+                "inside the `update_scripts` folder."
+            )
+
+    if first_run:
+        start_options["first_run_done"] = True
+
+        # Save start options
+        with open("start_options.json", "w") as start_file:
+            start_file.write(json.dumps(start_options))
+
+        print(
+            "Successfully wrote your start script options to `start_options.json`. \n"
+            "If something goes wrong, editing or deleting the file "
+            "will reinstall TabbyAPI as a first-time user."
+        )

     # Import entrypoint after installing all requirements
-    from main import entrypoint
+    try:
+        from main import entrypoint

-    asyncio.run(entrypoint(convert_args_to_dict(args, parser)))
+        converted_args = convert_args_to_dict(args, parser)
+
+        print("Starting TabbyAPI...")
+        entrypoint(converted_args)
+    except (ModuleNotFoundError, ImportError):
+        traceback.print_exc()
+        print(
+            "\n"
+            "This error was raised because a package was not found.\n"
+            "Update your dependencies by running update_scripts/"
+            f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'}\n\n"
+        )
diff --git a/update_scripts/update_deps.bat b/update_scripts/update_deps.bat
new file mode 100644
index 0000000..e03c827
--- /dev/null
+++ b/update_scripts/update_deps.bat
@@ -0,0 +1,24 @@
+@echo off
+
+:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades
+:: This is intended for users who want to start the API and have everything upgraded and installed
+
+:: cd to the parent directory
+cd "%~dp0.."
+
+:: Don't create a venv if a conda environment is active
+if exist "%CONDA_PREFIX%" (
+    echo It looks like you're in a conda environment. Skipping venv check.
+) else (
+    if not exist "venv\" (
+        echo Venv doesn't exist! Please run start.bat instead.
+        exit /b 1
+    )
+
+    call .\venv\Scripts\activate.bat
+)
+
+:: Call the python script with batch args
+call python start.py --update-deps %*
+
+pause
diff --git a/update_scripts/update_deps.sh b/update_scripts/update_deps.sh
new file mode 100755
index 0000000..becfa49
--- /dev/null
+++ b/update_scripts/update_deps.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+cd "$(dirname "$0")/.." || exit
+
+if [ -n "$CONDA_PREFIX" ]; then
+    echo "It looks like you're in a conda environment. Skipping venv check."
+else
+    if [ ! -d "venv" ]; then
+        echo "Venv doesn't exist! Please run start.sh instead."
+        exit 1
+    fi
+
+    echo "Activating venv"
+
+    # shellcheck source=/dev/null
+    source venv/bin/activate
+fi
+
+python3 start.py --update-deps "$@"
diff --git a/update_scripts/update_deps_and_pull.bat b/update_scripts/update_deps_and_pull.bat
new file mode 100644
index 0000000..c22866b
--- /dev/null
+++ b/update_scripts/update_deps_and_pull.bat
@@ -0,0 +1,24 @@
+@echo off
+
+:: Creates a venv if it doesn't exist and runs the start script for requirements upgrades
+:: This is intended for users who want to start the API and have everything upgraded and installed
+
+:: cd to the parent directory
+cd "%~dp0.."
+
+:: Don't create a venv if a conda environment is active
+if exist "%CONDA_PREFIX%" (
+    echo It looks like you're in a conda environment. Skipping venv check.
+) else (
+    if not exist "venv\" (
+        echo Venv doesn't exist! Please run start.bat instead.
+        exit /b 1
+    )
+
+    call .\venv\Scripts\activate.bat
+)
+
+:: Call the python script with batch args
+call python start.py --update-deps --update-repository %*
+
+pause
diff --git a/update_scripts/update_deps_and_pull.sh b/update_scripts/update_deps_and_pull.sh
new file mode 100755
index 0000000..4582cc5
--- /dev/null
+++ b/update_scripts/update_deps_and_pull.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+cd "$(dirname "$0")/.." || exit
+
+if [ -n "$CONDA_PREFIX" ]; then
+    echo "It looks like you're in a conda environment. Skipping venv check."
+else
+    if [ ! -d "venv" ]; then
+        echo "Venv doesn't exist! Please run start.sh instead."
+        exit 1
+    fi
+
+    echo "Activating venv"
+
+    # shellcheck source=/dev/null
+    source venv/bin/activate
+fi
+
+python3 start.py --update-deps --update-repository "$@"