From 1a331afe3a3d25920651e66d70bb5322184a7d6e Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sat, 16 Dec 2023 02:42:36 -0500
Subject: [PATCH] OAI: Add cache_mode parameter to model

Mistakenly forgot that the user can choose which cache mode to use when
loading a model. Also report the cache mode when fetching model info.

Signed-off-by: kingbri
---
 OAI/types/model.py | 2 ++
 main.py            | 1 +
 2 files changed, 3 insertions(+)

diff --git a/OAI/types/model.py b/OAI/types/model.py
index 9ba252f..f8922d9 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -8,6 +8,7 @@ class ModelCardParameters(BaseModel):
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
     prompt_template: Optional[str] = None
+    cache_mode: Optional[str] = "FP16"
     draft: Optional['ModelCard'] = None

 class ModelCard(BaseModel):
@@ -37,6 +38,7 @@ class ModelLoadRequest(BaseModel):
     rope_alpha: Optional[float] = 1.0
     no_flash_attention: Optional[bool] = False
     # low_mem: Optional[bool] = False
+    cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
     draft: Optional[DraftModelLoadRequest] = None

diff --git a/main.py b/main.py
index 6a092c0..b204be7 100644
--- a/main.py
+++ b/main.py
@@ -82,6 +82,7 @@ async def get_current_model():
             rope_scale = model_container.config.scale_pos_emb,
             rope_alpha = model_container.config.scale_alpha_value,
             max_seq_len = model_container.config.max_seq_len,
+            cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
             prompt_template = unwrap(model_container.prompt_template, "auto")
         ),
         logging = gen_logging.config
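
For context, a minimal sketch (not part of the patch) of how the new request field behaves, assuming pydantic is installed; the class below is a reduced stand-in for ModelLoadRequest containing only the fields visible in the diff:

# Minimal sketch, not part of the patch: a reduced stand-in for
# ModelLoadRequest with only the fields shown in the diff above.
from typing import Optional
from pydantic import BaseModel

class ModelLoadRequest(BaseModel):
    rope_scale: Optional[float] = 1.0
    rope_alpha: Optional[float] = 1.0
    no_flash_attention: Optional[bool] = False
    cache_mode: Optional[str] = "FP16"
    prompt_template: Optional[str] = None

# Load requests that omit cache_mode keep the FP16 default...
print(ModelLoadRequest().cache_mode)                  # FP16
# ...while clients can opt into the quantized FP8 cache explicitly.
print(ModelLoadRequest(cache_mode="FP8").cache_mode)  # FP8

On the read side, the main.py hunk maps the container's boolean cache_fp8 flag back onto the same two strings, so whatever route serves get_current_model (the decorator is outside this hunk) reports the cache mode the model was actually loaded with.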