diff --git a/OAI/types/model.py b/OAI/types/model.py
index 9ba252f..f8922d9 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -8,6 +8,7 @@ class ModelCardParameters(BaseModel):
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
     prompt_template: Optional[str] = None
+    cache_mode: Optional[str] = "FP16"
     draft: Optional['ModelCard'] = None
 
 class ModelCard(BaseModel):
@@ -37,6 +38,7 @@ class ModelLoadRequest(BaseModel):
     rope_alpha: Optional[float] = 1.0
     no_flash_attention: Optional[bool] = False
     # low_mem: Optional[bool] = False
+    cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
     draft: Optional[DraftModelLoadRequest] = None
 
diff --git a/main.py b/main.py
index 6a092c0..b204be7 100644
--- a/main.py
+++ b/main.py
@@ -82,6 +82,7 @@ async def get_current_model():
             rope_scale = model_container.config.scale_pos_emb,
             rope_alpha = model_container.config.scale_alpha_value,
             max_seq_len = model_container.config.max_seq_len,
+            cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
             prompt_template = unwrap(model_container.prompt_template, "auto")
         ),
         logging = gen_logging.config