diff --git a/.gitignore b/.gitignore index 90f8fbb..effb432 100644 --- a/.gitignore +++ b/.gitignore @@ -173,4 +173,7 @@ poetry.toml # LSP config files pyrightconfig.json -# End of https://www.toptal.com/developers/gitignore/api/python \ No newline at end of file +# End of https://www.toptal.com/developers/gitignore/api/python + +# User configuration +config.yml diff --git a/config_sample.yml b/config_sample.yml new file mode 100644 index 0000000..c18197d --- /dev/null +++ b/config_sample.yml @@ -0,0 +1,8 @@ +model_dir: "D:/models" +model_name: "this_is_an_exl2_model" +max_seq_len: 4096 +gpu_split: "auto" +rope_scale: 1.0 +rope_alpha: 1.0 +no_flash_attn: False +low_mem: False diff --git a/main.py b/main.py index 7e11f3d..b660547 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,9 @@ -import os -import argparse import uvicorn +import yaml from fastapi import FastAPI, HTTPException from pydantic import BaseModel from model import ModelContainer -from utils import add_args +from progress.bar import IncrementalBar app = FastAPI() @@ -38,22 +37,30 @@ def generate_text(request: TextRequest): except RuntimeError as e: raise HTTPException(status_code=500, detail=str(e)) -# Debug progress check -def progress(module, modules): - print(f"Loaded {module}/{modules} modules") - yield +# Wrapper callback for load progress +def load_progress(module, modules): + yield module, modules if __name__ == "__main__": - # Convert this parser to use a YAML config - parser = argparse.ArgumentParser(description = "TabbyAPI - An API server for exllamav2") - add_args(parser) - args = parser.parse_args() + # Load from YAML config. 
Possibly add a config -> kwargs conversion function + with open('config.yml', 'r') as config_file: + config = yaml.safe_load(config_file) + + # If an initial model name is specified, create a container and load the model + if config["model_name"]: + model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}" + + model_container = ModelContainer(model_path, False, **config) + load_status = model_container.load_gen(load_progress) + for (module, modules) in load_status: + if module == 0: + loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + else: + loading_bar.next() + + if module == modules: + loading_bar.finish() - # If an initial model dir is specified, create a container and load the model - if args.model_dir: - model_container = ModelContainer(args.model_dir, False, **vars(args)) - print("Loading an initial model...") - model_container.load(progress) print("Model successfully loaded.") # Reload is for dev purposes ONLY! 
diff --git a/model.py b/model.py index 0285faa..4e420a2 100644 --- a/model.py +++ b/model.py @@ -34,7 +34,6 @@ class ModelContainer: gpu_split: list or None = None def __init__(self, model_directory: str, quiet = False, **kwargs): - print(kwargs) """ Create model container @@ -76,6 +75,9 @@ class ModelContainer: if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"] if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"] + if "low_mem" in kwargs and kwargs["low_mem"]: + self.config.set_low_mem() + chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len) self.config.max_input_len = chunk_size self.config.max_attn_size = chunk_size ** 2 diff --git a/requirements.txt b/requirements.txt index 56744dc..90eaff9 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/utils.py b/utils.py deleted file mode 100644 index 9d82216..0000000 --- a/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -def add_args(parser): - parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory") - parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB") - parser.add_argument("-l", "--max_seq_len", type = int, help = "Maximum sequence length") - parser.add_argument("-rs", "--rope_scale", type = float, default = 1.0, help = "RoPE scaling factor") - parser.add_argument("-ra", "--rope_alpha", type = float, default = 1.0, help = "RoPE alpha value (NTK)") - parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention") - parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")