From eee8b642bd21455d7169a4ceb8c044f56b5308ad Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 13 Nov 2023 18:24:12 -0500 Subject: [PATCH] OAI: Implement completion API endpoint Add support for /v1/completions with the option to use streaming if needed. Also rewrite API endpoints to use async when possible since that improves request performance. Model container parameter names also needed rewrites as well and set fallback cases to their disabled values. Signed-off-by: kingbri --- OAI/models/common.py | 13 +++++ OAI/models/completions.py | 99 ++++++++++++++++++++++++++++++++++++++ OAI/utils.py | 19 ++++++++ main.py | 51 +++++++++----------- model.py | 65 ++++++++++++++----------- requirements.txt | Bin 50 -> 78 bytes 6 files changed, 190 insertions(+), 57 deletions(-) create mode 100644 OAI/models/common.py create mode 100644 OAI/models/completions.py create mode 100644 OAI/utils.py diff --git a/OAI/models/common.py b/OAI/models/common.py new file mode 100644 index 0000000..a86341d --- /dev/null +++ b/OAI/models/common.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field +from typing import List, Dict + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[float] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Dict[str, float]] = Field(default_factory=list) + +class UsageStats(BaseModel): + completion_tokens: int + prompt_tokens: int + total_tokens: int diff --git a/OAI/models/completions.py b/OAI/models/completions.py new file mode 100644 index 0000000..ba107f6 --- /dev/null +++ b/OAI/models/completions.py @@ -0,0 +1,99 @@ +from uuid import uuid4 +from time import time +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Union +from OAI.models.common import LogProbs, UsageStats + +class CompletionRespChoice(BaseModel): + finish_reason: str + index: int + logprobs: Optional[LogProbs] = None + text: str + +class CompletionRequest(BaseModel): + # Model information + model: str + + # Prompt can also contain token ids, but that's out of scope for this project. + prompt: Union[str, List[str]] + + # Extra OAI request stuff + best_of: Optional[int] = None + echo: Optional[bool] = False + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + n: Optional[int] = 1 + suffix: Optional[str] = None + user: Optional[str] = None + + # Generation info + seed: Optional[int] = -1 + stream: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = None + + # Default to 150 as 16 makes no sense as a default + max_tokens: Optional[int] = 150 + + # Not supported sampling params + presence_penalty: Optional[int] = 0 + + # Aliased to repetition_penalty + frequency_penalty: int = 0 + + # Sampling params + token_healing: Optional[bool] = False + temperature: Optional[float] = 1.0 + top_k: Optional[int] = 0 + top_p: Optional[float] = 1.0 + typical: Optional[float] = 0.0 + min_p: Optional[float] = 0.0 + tfs: Optional[float] = 1.0 + repetition_penalty: Optional[float] = 1.0 + repetition_penalty_range: Optional[int] = 0 + repetition_decay: Optional[int] = 0 + mirostat_mode: Optional[int] = 0 + mirostat_tau: Optional[float] = 1.5 + mirostat_eta: Optional[float] = 0.1 + + # Converts to internal generation parameters + def to_gen_params(self): + # Convert prompt to a string + if isinstance(self.prompt, list): + self.prompt = "\n".join(self.prompt) + + # Convert stop to an array of strings + if isinstance(self.stop, str): + self.stop = [self.stop] + + # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined + if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty: + self.repetition_penalty = self.frequency_penalty + + return { + "prompt": self.prompt, + "stop": self.stop, + "max_tokens": self.max_tokens, + "token_healing": self.token_healing, + "temperature": self.temperature, + "top_k": self.top_k, + "top_p": self.top_p, + "typical": self.typical, + "min_p": self.min_p, + "tfs": self.tfs, + "repetition_penalty": self.repetition_penalty, + "repetition_penalty_range": self.repetition_penalty_range, + "repetition_decay": self.repetition_decay, + "mirostat": True if self.mirostat_mode == 2 else False, + "mirostat_tau": self.mirostat_tau, + "mirostat_eta": self.mirostat_eta + } + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}") + choices: List[CompletionRespChoice] + created: int = Field(default_factory=lambda: int(time())) + model: str + object: str = "text-completion" + + # TODO: Add usage stats + usage: Optional[UsageStats] = None diff --git a/OAI/utils.py b/OAI/utils.py new file mode 100644 index 0000000..ebd80e3 --- /dev/null +++ b/OAI/utils.py @@ -0,0 +1,19 @@ +from OAI.models.completions import CompletionResponse, CompletionRespChoice +from OAI.models.common import UsageStats +from typing import Optional + +def create_completion_response(text: str, index: int, model_name: Optional[str]): + # TODO: Add method to get token amounts in model for UsageStats + + choice = CompletionRespChoice( + finish_reason="Generated", + index = index, + text = text + ) + + response = CompletionResponse( + choices = [choice], + model = model_name or "" + ) + + return response diff --git a/main.py b/main.py index b660547..d9b53e3 100644 --- a/main.py +++ b/main.py @@ -1,41 +1,37 @@ import uvicorn import yaml -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from fastapi import FastAPI, Request from model import ModelContainer from progress.bar import IncrementalBar +from sse_starlette import EventSourceResponse +from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice +from OAI.utils import create_completion_response app = FastAPI() # Initialize a model container. This can be undefined at any period of time model_container: ModelContainer = None -class TextRequest(BaseModel): - model: str = None # Make the "model" field optional with a default value of None - prompt: str - max_tokens: int = 200 - temperature: float = 1 - top_p: float = 0.9 - seed: int = 10 - stream: bool = False - token_repetition_penalty: float = 1.0 - stop: list = None +@app.post("/v1/completions") +async def generate_completion(request: Request, data: CompletionRequest): + if data.stream: + async def generator(): + new_generation = model_container.generate_gen(**data.to_gen_params()) + for index, part in enumerate(new_generation): + if await request.is_disconnected(): + break -class TextResponse(BaseModel): - response: str - generation_time: str + response = create_completion_response(part, index, model_container.get_model_name()) + + yield response.model_dump_json() + + return EventSourceResponse(generator()) + else: + response_text = model_container.generate(**data.to_gen_params()) + response = create_completion_response(response_text, 0, model_container.get_model_name()) + + return response.model_dump_json() -# TODO: Currently broken -@app.post("/generate-text", response_model=TextResponse) -def generate_text(request: TextRequest): - global modelManager - try: - prompt = request.prompt # Get the prompt from the request - user_message = prompt # Assuming that prompt is equivalent to the user's message - output, generation_time = modelManager.generate_text(prompt=user_message) - return {"response": output, "generation_time": generation_time} - except RuntimeError as e: - raise HTTPException(status_code=500, detail=str(e)) # Wrapper callback for load progress def load_progress(module, modules): @@ -63,5 +59,4 @@ if __name__ == "__main__": print("Model successfully loaded.") - # Reload is for dev purposes ONLY! - uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True) + uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug") diff --git a/model.py b/model.py index fb222d0..b7368b0 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,5 @@ import gc, time import torch - from exllamav2 import( ExLlamaV2, ExLlamaV2Config, @@ -8,25 +7,26 @@ from exllamav2 import( ExLlamaV2Cache_8bit, ExLlamaV2Tokenizer, ) - from exllamav2.generator import( ExLlamaV2StreamingGenerator, ExLlamaV2Sampler ) +from os import path +from typing import Optional # Bytes to reserve on first device when loading with auto split auto_split_reserve_bytes = 96 * 1024**2 class ModelContainer: - config: ExLlamaV2Config or None = None - draft_config: ExLlamaV2Config or None = None - model: ExLlamaV2 or None = None - draft_model: ExLlamaV2 or None = None - cache: ExLlamaV2Cache or None = None - draft_cache: ExLlamaV2Cache or None = None - tokenizer: ExLlamaV2Tokenizer or None = None - generator: ExLlamaV2StreamingGenerator or None = None + config: Optional[ExLlamaV2Config] = None + draft_config: Optional[ExLlamaV2Config] = None + model: Optional[ExLlamaV2] = None + draft_model: Optional[ExLlamaV2] = None + cache: Optional[ExLlamaV2Cache] = None + draft_cache: Optional[ExLlamaV2Cache] = None + tokenizer: Optional[ExLlamaV2Tokenizer] = None + generator: Optional[ExLlamaV2StreamingGenerator] = None cache_fp8: bool = False draft_enabled: bool = False @@ -102,6 +102,11 @@ class ModelContainer: self.draft_config.max_input_len = kwargs["chunk_size"] self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2 + def get_model_name(self): + if self.draft_enabled: + return path.basename(path.normpath(self.draft_config.model_dir)) + else: + return path.basename(path.normpath(self.config.model_dir)) def load(self, progress_callback = None): """ @@ -201,20 +206,20 @@ class ModelContainer: prompt (str): Input prompt **kwargs: 'token_healing' (bool): Use token healing (default: False) - 'temperature' (float): Sampling temperature (default: 0.8) - 'top_k' (int): Sampling top-K (default: 100) - 'top_p' (float): Sampling top-P (default: 0.8) + 'temperature' (float): Sampling temperature (default: 1.0) + 'top_k' (int): Sampling top-K (default: 0) + 'top_p' (float): Sampling top-P (default: 1.0) 'min_p' (float): Sampling min-P (default: 0.0) 'tfs' (float): Tail-free sampling (default: 0.0) 'typical' (float): Sampling typical (default: 0.0) 'mirostat' (bool): Use Mirostat (default: False) 'mirostat_tau' (float) Mirostat tau parameter (default: 1.5) 'mirostat_eta' (float) Mirostat eta parameter (default: 0.1) - 'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15) - 'token_repetition_range' (int): Repetition penalty range (default: whole context) - 'token_repetition_decay' (int): Repetition penalty range (default: same as range) - 'stop_conditions' (list): List of stop strings/tokens to end response (default: [EOS]) - 'max_new_tokens' (int): Max no. tokens in response (default: 150) + 'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15) + 'repetition_range' (int): Repetition penalty range (default: whole context) + 'repetition_decay' (int): Repetition penalty range (default: same as range) + 'stop' (list): List of stop strings/tokens to end response (default: [EOS]) + 'max_tokens' (int): Max no. tokens in response (default: 150) 'stream_interval' (float): Interval in seconds between each output chunk (default: immediate) 'generate_window' (int): Space to reserve at the end of the model's context when generating. Rolls context window by the same amount if context length is exceeded to allow generating past @@ -223,25 +228,27 @@ class ModelContainer: """ token_healing = kwargs.get("token_healing", False) - max_new_tokens = kwargs.get("max_new_tokens", 150) + max_tokens = kwargs.get("max_tokens", 150) stream_interval = kwargs.get("stream_interval", 0) - generate_window = min(kwargs.get("generate_window", 512), max_new_tokens) + generate_window = min(kwargs.get("generate_window", 512), max_tokens) # Sampler settings gen_settings = ExLlamaV2Sampler.Settings() - gen_settings.temperature = kwargs.get("temperature", 0.8) - gen_settings.top_k = kwargs.get("top_k", 100) - gen_settings.top_p = kwargs.get("top_p", 0.8) + gen_settings.temperature = kwargs.get("temperature", 1.0) + gen_settings.top_k = kwargs.get("top_k", 1) + gen_settings.top_p = kwargs.get("top_p", 1.0) gen_settings.min_p = kwargs.get("min_p", 0.0) gen_settings.tfs = kwargs.get("tfs", 0.0) gen_settings.typical = kwargs.get("typical", 0.0) gen_settings.mirostat = kwargs.get("mirostat", False) + + # Default tau and eta fallbacks don't matter if mirostat is off gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5) gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1) - gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15) - gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len) - gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range) + gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0) + gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len) + gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range) # Override sampler settings for temp = 0 @@ -253,7 +260,7 @@ class ModelContainer: # Stop conditions - self.generator.set_stop_conditions(kwargs.get("stop_conditions", [self.tokenizer.eos_token_id])) + self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id])) # Tokenized context @@ -302,10 +309,10 @@ class ModelContainer: now = time.time() elapsed = now - last_chunk_time - if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_new_tokens): + if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens): yield chunk_buffer full_response += chunk_buffer chunk_buffer = "" last_chunk_time = now - if eos or generated_tokens == max_new_tokens: break \ No newline at end of file + if eos or generated_tokens == max_tokens: break \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 90eaff93d0fc113adbc0e4fa7ddd57d992f5c7d3..25c5fc6a0f6500c32c022bd302677600a6898356 100644 GIT binary patch delta 33 kcmXr=n_whU%uvjb$`B7EOBfOviWqW$yb>TxW#D1}0D~b0>i_@% delta 4 LcmeY>nqUL~10(@M