diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py new file mode 100644 index 0000000..ac5ce1a --- /dev/null +++ b/backends/exllamav2/utils.py @@ -0,0 +1,31 @@ +from packaging import version +from importlib.metadata import version as package_version + +from common.logger import init_logger + +logger = init_logger(__name__) + + +def check_exllama_version(): + """Verifies the exllama version""" + + required_version = "0.0.12" + current_version = package_version("exllamav2").split("+")[0] + + if version.parse(current_version) < version.parse(required_version): + raise SystemExit( + f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " + f"or greater. Your current version is {current_version}.\n" + "Please upgrade your environment by running a start script " + "(start.bat or start.sh)\n\n" + "Or you can manually run a requirements update " + "using the following command:\n\n" + "For CUDA 12.1:\n" + "pip install --upgrade -r requirements.txt\n\n" + "For CUDA 11.8:\n" + "pip install --upgrade -r requirements-cu118.txt\n\n" + "For ROCm:\n" + "pip install --upgrade -r requirements-amd.txt\n\n" + ) + else: + logger.info(f"ExllamaV2 version: {current_version}") diff --git a/common/args.py b/common/args.py index 1aee327..0696c4a 100644 --- a/common/args.py +++ b/common/args.py @@ -125,3 +125,12 @@ def add_logging_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Enable generation parameter logging", ) + + +def add_developer_args(parser: argparse.ArgumentParser): + """Adds developer-specific arguments""" + + developer_group = parser.add_argument_group("developer") + developer_group.add_argument( + "--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check" + ) diff --git a/common/config.py b/common/config.py index 9a4b7b1..f02e48d 100644 --- a/common/config.py +++ b/common/config.py @@ -55,6 +55,11 @@ def override_config_from_args(args: dict): **{k.replace("log_", ""): logging_override[k] for k in logging_override}, } + developer_override = args.get("developer") + if developer_override: + developer_config = get_developer_config() + GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override} + def get_sampling_config(): """Returns the sampling parameter config from the global config""" @@ -86,3 +91,8 @@ def get_network_config(): def get_gen_logging_config(): """Returns the generation logging config from the global config""" return unwrap(GLOBAL_CONFIG.get("logging"), {}) + + +def get_developer_config(): + """Returns the developer specific config from the global config""" + return unwrap(GLOBAL_CONFIG.get("developer"), {}) diff --git a/config_sample.yml b/config_sample.yml index cf1ddb5..3e6f60f 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -35,6 +35,13 @@ sampling: # WARNING: Using this can result in a generation speed penalty #override_preset: +# Options for development +developer: + # Skips exllamav2 version check (default: False) + # It's highly recommended to update your dependencies rather than enabling this flag + # WARNING: Don't set this unless you know what you're doing! + #unsafe_launch: False + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/main.py b/main.py index 921ebbe..c47a9fc 100644 --- a/main.py +++ b/main.py @@ -9,15 +9,15 @@ from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from functools import partial -from packaging import version -from importlib.metadata import version as package_version from progress.bar import IncrementalBar import common.gen_logging as gen_logging from backends.exllamav2.model import ExllamaV2Container +from backends.exllamav2.utils import check_exllama_version from common.args import convert_args_to_dict, init_argparser from common.auth import check_admin_key, check_api_key, load_auth_keys from common.config import ( + get_developer_config, get_sampling_config, override_config_from_args, read_config_from_file, @@ -580,26 +580,6 @@ def entrypoint(args: Optional[dict] = None): """Entry function for program startup""" global MODEL_CONTAINER - # Check exllamav2 version and give a descriptive error if it's too old - required_exl_version = "0.0.12" - current_exl_version = package_version("exllamav2").split("+")[0] - - if version.parse(current_exl_version) < version.parse(required_exl_version): - raise SystemExit( - f"TabbyAPI requires ExLlamaV2 {required_exl_version} " - f"or greater. Your current version is {current_exl_version}.\n" - "Please upgrade your environment by running a start script " - "(start.bat or start.sh)\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade -r requirements.txt\n\n" - "For CUDA 11.8:\n" - "pip install --upgrade -r requirements-cu118.txt\n\n" - "For ROCm:\n" - "pip install --upgrade -r requirements-amd.txt\n\n" - ) - # Load from YAML config read_config_from_file(pathlib.Path("config.yml")) @@ -610,6 +590,19 @@ def entrypoint(args: Optional[dict] = None): override_config_from_args(args) + developer_config = get_developer_config() + + # Check exllamav2 version and give a descriptive error if it's too old + # Skip if launching unsafely + + if unwrap(developer_config.get("unsafe_launch"), False): + logger.warning( + "UNSAFE: Skipping ExllamaV2 version check.\n" + "If you aren't a developer, please keep this off!" + ) + else: + check_exllama_version() + network_config = get_network_config() # Initialize auth keys