From 78a6587b95cf53b4964ef34315af102c834aef27 Mon Sep 17 00:00:00 2001 From: waldfee Date: Fri, 17 Nov 2023 22:08:31 +0100 Subject: [PATCH 1/2] add cache_mode and draft_model_dir to config_sample.yml --- config_sample.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/config_sample.yml b/config_sample.yml index 986990c..466329b 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -20,6 +20,10 @@ model: # A model can be loaded later via the API. This does not have to be specified # model_name: A model name + # Set the following to enable speculative decoding + # draft_model_dir: your model directory path to use as draft model (path is independent from model_dir) + # draft_rope_alpha: 1.0 (default: the draft model's alpha value is calculated automatically to scale to the size of the full model.) + # The below parameters apply only if model_name is set # Maximum model context length (default: 4096) @@ -40,3 +44,6 @@ model: # Enable low vram optimizations in exllamav2 (default: False) low_mem: False + + # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16) + # cache_mode: "FP8" \ No newline at end of file From 27ebec3b35ce39712c9368839afb9a92306fb916 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 18 Nov 2023 01:38:54 -0500 Subject: [PATCH 2/2] Model: Add speculative decoding support via config Speculative decoding makes use of draft models that ingest the prompt before forwarding it to the main model. Add options in the config to support this. API options will occur in a different commit. 
Signed-off-by: kingbri --- config_sample.yml | 18 +++++++++++++++--- main.py | 16 ++++++++-------- model.py | 18 +++++++++++++++--- 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/config_sample.yml b/config_sample.yml index 466329b..2fb0225 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -13,11 +13,11 @@ network: # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) - # Windows users: DO NOT put this path in quotes! This directory will be invalid otherwise. + # Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise. # model_dir: your model directory path # An initial model to load. Make sure the model is located in the model directory! - # A model can be loaded later via the API. This does not have to be specified + # A model can be loaded later via the API. # model_name: A model name # Set the following to enable speculative decoding @@ -46,4 +46,16 @@ model: low_mem: False # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16) - # cache_mode: "FP8" \ No newline at end of file + # cache_mode: FP16 + + # Options for draft models (speculative decoding). This will use more VRAM! + # draft: + # Overrides the directory to look for draft models (default: models) + # draft_model_dir: Your draft model directory path + + # An initial draft model to load. Make sure this model is located in the model directory! + # A draft model can be loaded later via the API. 
+ # draft_model_name: A model name + + # Rope parameters for draft models (default: 1.0) + # draft_rope_alpha: 1.0 diff --git a/main.py b/main.py index 05c880e..747c120 100644 --- a/main.py +++ b/main.py @@ -56,7 +56,7 @@ async def list_models(): else: model_path = pathlib.Path("models") - models = get_model_list(model_path) + models = get_model_list(model_path.resolve()) return models @@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest): def generator(): global model_container - model_config = config.get("model", {}) + model_config = config.get("model") or {} if "model_dir" in model_config: model_path = pathlib.Path(model_config["model_dir"]) else: @@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest): model_path = model_path / data.name - model_container = ModelContainer(model_path, False, **data.dict()) + model_container = ModelContainer(model_path.resolve(), False, **data.dict()) load_status = model_container.load_gen(load_progress) for (module, modules) in load_status: if module == 0: @@ -217,12 +217,12 @@ if __name__ == "__main__": # If an initial model name is specified, create a container and load the model - model_config = config.get("model", {}) + model_config = config.get("model") or {} if "model_name" in model_config: - model_path = pathlib.Path(model_config.get("model_dir", "models")) - model_path = model_path / model_config["model_name"] + model_path = pathlib.Path(model_config.get("model_dir") or "models") + model_path = model_path / model_config.get("model_name") - model_container = ModelContainer(model_path, False, **model_config) + model_container = ModelContainer(model_path.resolve(), False, **model_config) load_status = model_container.load_gen(load_progress) for (module, modules) in load_status: if module == 0: @@ -233,7 +233,7 @@ if __name__ == "__main__": else: loading_bar.next() - network_config = config.get("network", {}) + network_config = config.get("network") or {} uvicorn.run( app, host=network_config.get("host", 
"127.0.0.1"), diff --git a/model.py b/model.py index 8aa1872..536fb66 100644 --- a/model.py +++ b/model.py @@ -82,17 +82,29 @@ class ModelContainer: self.config.max_input_len = chunk_size self.config.max_attn_size = chunk_size ** 2 - self.draft_enabled = "draft_model_dir" in kwargs + draft_config = kwargs.get("draft") or {} + draft_model_name = draft_config.get("draft_model_name") + enable_draft = bool(draft_config) and draft_model_name is not None + + if bool(draft_config) and draft_model_name is None: + print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.") + self.draft_enabled = False + else: + self.draft_enabled = enable_draft + if self.draft_enabled: self.draft_config = ExLlamaV2Config() - self.draft_config.model_dir = kwargs["draft_model_dir"] + draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models") + draft_model_path = draft_model_path / draft_model_name + + self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() self.draft_config.max_seq_len = self.config.max_seq_len if "draft_rope_alpha" in kwargs: - self.draft_config.scale_alpha_value = kwargs["draft_rope_alpha"] + self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1 else: ratio = self.config.max_seq_len / self.draft_config.max_seq_len alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2