diff --git a/config_sample.yml b/config_sample.yml
index 986990c..2fb0225 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -13,13 +13,17 @@ network:
 # Options for model overrides and loading
 model:
   # Overrides the directory to look for models (default: models)
-  # Windows users: DO NOT put this path in quotes! This directory will be invalid otherwise.
+  # Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
   # model_dir: your model directory path
 
   # An initial model to load. Make sure the model is located in the model directory!
-  # A model can be loaded later via the API. This does not have to be specified
+  # A model can be loaded later via the API.
   # model_name: A model name
 
+  # Set the following to enable speculative decoding
+  # draft_model_dir: your model directory path to use as draft model (path is independent from model_dir)
+  # draft_rope_alpha: 1.0 (default: the draft model's alpha value is calculated automatically to scale to the size of the full model.)
+
   # The below parameters apply only if model_name is set
 
   # Maximum model context length (default: 4096)
@@ -40,3 +44,18 @@ model:
 
   # Enable low vram optimizations in exllamav2 (default: False)
   low_mem: False
+
+  # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
+  # cache_mode: FP16
+
+  # Options for draft models (speculative decoding). This will use more VRAM!
+  # draft:
+    # Overrides the directory to look for draft (default: models)
+    # draft_model_dir: Your draft model directory path
+
+    # An initial draft model to load. Make sure this model is located in the model directory!
+    # A draft model can be loaded later via the API.
+    # draft_model_name: A model name
+
+    # Rope parameters for draft models (default: 1.0)
+    # draft_rope_alpha: 1.0
diff --git a/main.py b/main.py
index 05c880e..747c120 100644
--- a/main.py
+++ b/main.py
@@ -56,7 +56,7 @@ async def list_models():
     else:
         model_path = pathlib.Path("models")
 
-    models = get_model_list(model_path)
+    models = get_model_list(model_path.resolve())
 
     return models
 
@@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
     def generator():
         global model_container
 
-        model_config = config.get("model", {})
+        model_config = config.get("model") or {}
         if "model_dir" in model_config:
             model_path = pathlib.Path(model_config["model_dir"])
         else:
@@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
 
         model_path = model_path / data.name
 
-        model_container = ModelContainer(model_path, False, **data.dict())
+        model_container = ModelContainer(model_path.resolve(), False, **data.dict())
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -217,12 +217,12 @@ if __name__ == "__main__":
 
     # If an initial model name is specified, create a container and load the model
 
-    model_config = config.get("model", {})
+    model_config = config.get("model") or {}
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model_dir", "models"))
-        model_path = model_path / model_config["model_name"]
+        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = model_path / model_config.get("model_name")
 
-        model_container = ModelContainer(model_path, False, **model_config)
+        model_container = ModelContainer(model_path.resolve(), False, **model_config)
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -233,7 +233,7 @@ if __name__ == "__main__":
             else:
                 loading_bar.next()
 
-    network_config = config.get("network", {})
+    network_config = config.get("network") or {}
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),
diff --git a/model.py b/model.py
index fd9a51d..a10b69f 100644
--- a/model.py
+++ b/model.py
@@ -82,17 +82,29 @@ class ModelContainer:
             self.config.max_input_len = chunk_size
             self.config.max_attn_size = chunk_size ** 2
 
-        self.draft_enabled = "draft_model_dir" in kwargs
+        draft_config = kwargs.get("draft") or {}
+        draft_model_name = draft_config.get("draft_model_name")
+        enable_draft = bool(draft_config) and draft_model_name is not None
+
+        if bool(draft_config) and draft_model_name is None:
+            print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
+            self.draft_enabled = False
+        else:
+            self.draft_enabled = enable_draft
+
         if self.draft_enabled:
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.model_dir = kwargs["draft_model_dir"]
+            draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models")
+            draft_model_path = draft_model_path / draft_model_name
+
+            self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
             self.draft_config.max_seq_len = self.config.max_seq_len
 
             if "draft_rope_alpha" in kwargs:
-                self.draft_config.scale_alpha_value = kwargs["draft_rope_alpha"]
+                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
             else:
                 ratio = self.config.max_seq_len / self.draft_config.max_seq_len
                 alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2