diff --git a/OAI/models/common.py b/OAI/models/common.py
new file mode 100644
index 0000000..a86341d
--- /dev/null
+++ b/OAI/models/common.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel, Field
+from typing import List, Dict
+
+class LogProbs(BaseModel):
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[float] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
+
+class UsageStats(BaseModel):
+    completion_tokens: int
+    prompt_tokens: int
+    total_tokens: int
diff --git a/OAI/models/completions.py b/OAI/models/completions.py
new file mode 100644
index 0000000..ba107f6
--- /dev/null
+++ b/OAI/models/completions.py
@@ -0,0 +1,99 @@
+from uuid import uuid4
+from time import time
+from pydantic import BaseModel, Field
+from typing import List, Optional, Dict, Union
+from OAI.models.common import LogProbs, UsageStats
+
+class CompletionRespChoice(BaseModel):
+    finish_reason: str
+    index: int
+    logprobs: Optional[LogProbs] = None
+    text: str
+
+class CompletionRequest(BaseModel):
+    # Model information
+    model: str
+
+    # Prompt can also contain token ids, but that's out of scope for this project.
+    prompt: Union[str, List[str]]
+
+    # Extra OAI request stuff
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    n: Optional[int] = 1
+    suffix: Optional[str] = None
+    user: Optional[str] = None
+
+    # Generation info
+    seed: Optional[int] = -1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+
+    # Default to 150, as 16 makes no sense as a default
+    max_tokens: Optional[int] = 150
+
+    # Unsupported sampling params
+    presence_penalty: Optional[float] = 0.0
+
+    # Aliased to repetition_penalty
+    frequency_penalty: Optional[float] = 0.0
+
+    # Sampling params
+    token_healing: Optional[bool] = False
+    temperature: Optional[float] = 1.0
+    top_k: Optional[int] = 0
+    top_p: Optional[float] = 1.0
+    typical: Optional[float] = 0.0
+    min_p: Optional[float] = 0.0
+    tfs: Optional[float] = 1.0
+    repetition_penalty: Optional[float] = 1.0
+    repetition_penalty_range: Optional[int] = 0
+    repetition_decay: Optional[int] = 0
+    mirostat_mode: Optional[int] = 0
+    mirostat_tau: Optional[float] = 1.5
+    mirostat_eta: Optional[float] = 0.1
+
+    # Converts to internal generation parameters
+    def to_gen_params(self):
+        # Convert prompt to a string
+        if isinstance(self.prompt, list):
+            self.prompt = "\n".join(self.prompt)
+
+        # Convert stop to a list of strings
+        if isinstance(self.stop, str):
+            self.stop = [self.stop]
+
+        # Alias frequency_penalty to repetition_penalty if repetition_penalty isn't already set
+        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
+            self.repetition_penalty = self.frequency_penalty
+
+        return {
+            "prompt": self.prompt,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "token_healing": self.token_healing,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "typical": self.typical,
+            "min_p": self.min_p,
+            "tfs": self.tfs,
+            "repetition_penalty": self.repetition_penalty,
+            "repetition_range": self.repetition_penalty_range,
+            "repetition_decay": self.repetition_decay,
+            "mirostat": self.mirostat_mode == 2,
+            "mirostat_tau": self.mirostat_tau,
+            "mirostat_eta": self.mirostat_eta
+        }
+
+class CompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
f"cmpl-{uuid4().hex}") + choices: List[CompletionRespChoice] + created: int = Field(default_factory=lambda: int(time())) + model: str + object: str = "text-completion" + + # TODO: Add usage stats + usage: Optional[UsageStats] = None diff --git a/OAI/utils.py b/OAI/utils.py new file mode 100644 index 0000000..ebd80e3 --- /dev/null +++ b/OAI/utils.py @@ -0,0 +1,19 @@ +from OAI.models.completions import CompletionResponse, CompletionRespChoice +from OAI.models.common import UsageStats +from typing import Optional + +def create_completion_response(text: str, index: int, model_name: Optional[str]): + # TODO: Add method to get token amounts in model for UsageStats + + choice = CompletionRespChoice( + finish_reason="Generated", + index = index, + text = text + ) + + response = CompletionResponse( + choices = [choice], + model = model_name or "" + ) + + return response diff --git a/main.py b/main.py index b660547..d9b53e3 100644 --- a/main.py +++ b/main.py @@ -1,41 +1,37 @@ import uvicorn import yaml -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from fastapi import FastAPI, Request from model import ModelContainer from progress.bar import IncrementalBar +from sse_starlette import EventSourceResponse +from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice +from OAI.utils import create_completion_response app = FastAPI() # Initialize a model container. This can be undefined at any period of time model_container: ModelContainer = None -class TextRequest(BaseModel): - model: str = None # Make the "model" field optional with a default value of None - prompt: str - max_tokens: int = 200 - temperature: float = 1 - top_p: float = 0.9 - seed: int = 10 - stream: bool = False - token_repetition_penalty: float = 1.0 - stop: list = None +@app.post("/v1/completions") +async def generate_completion(request: Request, data: CompletionRequest): + if data.stream: + async def generator(): + new_generation = model_container.generate_gen(**data.to_gen_params()) + for index, part in enumerate(new_generation): + if await request.is_disconnected(): + break -class TextResponse(BaseModel): - response: str - generation_time: str + response = create_completion_response(part, index, model_container.get_model_name()) + + yield response.model_dump_json() + + return EventSourceResponse(generator()) + else: + response_text = model_container.generate(**data.to_gen_params()) + response = create_completion_response(response_text, 0, model_container.get_model_name()) + + return response.model_dump_json() -# TODO: Currently broken -@app.post("/generate-text", response_model=TextResponse) -def generate_text(request: TextRequest): - global modelManager - try: - prompt = request.prompt # Get the prompt from the request - user_message = prompt # Assuming that prompt is equivalent to the user's message - output, generation_time = modelManager.generate_text(prompt=user_message) - return {"response": output, "generation_time": generation_time} - except RuntimeError as e: - raise HTTPException(status_code=500, detail=str(e)) # Wrapper callback for load progress def load_progress(module, modules): @@ -63,5 +59,4 @@ if __name__ == "__main__": print("Model successfully loaded.") - # Reload is for dev purposes ONLY! 
- uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True) + uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug") diff --git a/model.py b/model.py index fb222d0..b7368b0 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,5 @@ import gc, time import torch - from exllamav2 import( ExLlamaV2, ExLlamaV2Config, @@ -8,25 +7,26 @@ from exllamav2 import( ExLlamaV2Cache_8bit, ExLlamaV2Tokenizer, ) - from exllamav2.generator import( ExLlamaV2StreamingGenerator, ExLlamaV2Sampler ) +from os import path +from typing import Optional # Bytes to reserve on first device when loading with auto split auto_split_reserve_bytes = 96 * 1024**2 class ModelContainer: - config: ExLlamaV2Config or None = None - draft_config: ExLlamaV2Config or None = None - model: ExLlamaV2 or None = None - draft_model: ExLlamaV2 or None = None - cache: ExLlamaV2Cache or None = None - draft_cache: ExLlamaV2Cache or None = None - tokenizer: ExLlamaV2Tokenizer or None = None - generator: ExLlamaV2StreamingGenerator or None = None + config: Optional[ExLlamaV2Config] = None + draft_config: Optional[ExLlamaV2Config] = None + model: Optional[ExLlamaV2] = None + draft_model: Optional[ExLlamaV2] = None + cache: Optional[ExLlamaV2Cache] = None + draft_cache: Optional[ExLlamaV2Cache] = None + tokenizer: Optional[ExLlamaV2Tokenizer] = None + generator: Optional[ExLlamaV2StreamingGenerator] = None cache_fp8: bool = False draft_enabled: bool = False @@ -102,6 +102,11 @@ class ModelContainer: self.draft_config.max_input_len = kwargs["chunk_size"] self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2 + def get_model_name(self): + if self.draft_enabled: + return path.basename(path.normpath(self.draft_config.model_dir)) + else: + return path.basename(path.normpath(self.config.model_dir)) def load(self, progress_callback = None): """ @@ -201,20 +206,20 @@ class ModelContainer: prompt (str): Input prompt **kwargs: 'token_healing' (bool): Use token healing (default: False) - 'temperature' (float): Sampling temperature (default: 0.8) - 'top_k' (int): Sampling top-K (default: 100) - 'top_p' (float): Sampling top-P (default: 0.8) + 'temperature' (float): Sampling temperature (default: 1.0) + 'top_k' (int): Sampling top-K (default: 0) + 'top_p' (float): Sampling top-P (default: 1.0) 'min_p' (float): Sampling min-P (default: 0.0) 'tfs' (float): Tail-free sampling (default: 0.0) 'typical' (float): Sampling typical (default: 0.0) 'mirostat' (bool): Use Mirostat (default: False) 'mirostat_tau' (float) Mirostat tau parameter (default: 1.5) 'mirostat_eta' (float) Mirostat eta parameter (default: 0.1) - 'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15) - 'token_repetition_range' (int): Repetition penalty range (default: whole context) - 'token_repetition_decay' (int): Repetition penalty range (default: same as range) - 'stop_conditions' (list): List of stop strings/tokens to end response (default: [EOS]) - 'max_new_tokens' (int): Max no. tokens in response (default: 150) + 'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15) + 'repetition_range' (int): Repetition penalty range (default: whole context) + 'repetition_decay' (int): Repetition penalty range (default: same as range) + 'stop' (list): List of stop strings/tokens to end response (default: [EOS]) + 'max_tokens' (int): Max no. 
                'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
                'generate_window' (int): Space to reserve at the end of the model's context when generating.
                    Rolls context window by the same amount if context length is exceeded to allow generating past
@@ -223,25 +228,27 @@
         """

         token_healing = kwargs.get("token_healing", False)
-        max_new_tokens = kwargs.get("max_new_tokens", 150)
+        max_tokens = kwargs.get("max_tokens", 150)
         stream_interval = kwargs.get("stream_interval", 0)
-        generate_window = min(kwargs.get("generate_window", 512), max_new_tokens)
+        generate_window = min(kwargs.get("generate_window", 512), max_tokens)

         # Sampler settings
         gen_settings = ExLlamaV2Sampler.Settings()
-        gen_settings.temperature = kwargs.get("temperature", 0.8)
-        gen_settings.top_k = kwargs.get("top_k", 100)
-        gen_settings.top_p = kwargs.get("top_p", 0.8)
+        gen_settings.temperature = kwargs.get("temperature", 1.0)
+        gen_settings.top_k = kwargs.get("top_k", 0)
+        gen_settings.top_p = kwargs.get("top_p", 1.0)
         gen_settings.min_p = kwargs.get("min_p", 0.0)
         gen_settings.tfs = kwargs.get("tfs", 0.0)
         gen_settings.typical = kwargs.get("typical", 0.0)
         gen_settings.mirostat = kwargs.get("mirostat", False)
+
+        # Default tau and eta fallbacks don't matter if mirostat is off
         gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
         gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
-        gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15)
-        gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len)
-        gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)
+        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
+        gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
+        gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)

         # Override sampler settings for temp = 0
@@ -253,7 +260,7 @@

         # Stop conditions

-        self.generator.set_stop_conditions(kwargs.get("stop_conditions", [self.tokenizer.eos_token_id]))
+        self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))

         # Tokenized context

@@ -302,10 +309,10 @@
                 now = time.time()
                 elapsed = now - last_chunk_time

-                if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_new_tokens):
+                if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
                     yield chunk_buffer
                     full_response += chunk_buffer
                     chunk_buffer = ""
                     last_chunk_time = now

-            if eos or generated_tokens == max_new_tokens: break
\ No newline at end of file
+            if eos or generated_tokens == max_tokens: break
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 90eaff9..25c5fc6 100644
Binary files a/requirements.txt and b/requirements.txt differ
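
Usage sketch (not part of the diff): the client below exercises the new /v1/completions route added in main.py. It assumes the server is running locally on the port hard-coded in main.py (8012), that the `requests` package is installed, and uses placeholder values for the model label and prompt; the field names and response shape come from CompletionRequest and CompletionResponse above.

# Hypothetical client for the new endpoint; `requests` is an assumed extra
# dependency and the model/prompt strings below are placeholders.
import json
import requests

BASE_URL = "http://localhost:8012/v1/completions"  # port taken from main.py

payload = {
    "model": "my-local-model",     # required by CompletionRequest; the server reports its own model name
    "prompt": "Once upon a time",  # placeholder prompt
    "max_tokens": 64,
    "temperature": 0.8,
    "top_p": 0.9,
    "stream": False,
}

# Non-streaming: the endpoint returns a single CompletionResponse as JSON.
resp = requests.post(BASE_URL, json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])

# Streaming: the endpoint returns Server-Sent Events, with one JSON-encoded
# CompletionResponse per generated chunk on each "data:" line.
with requests.post(BASE_URL, json={**payload, "stream": True}, stream=True, timeout=120) as stream_resp:
    for line in stream_resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data:"):
            chunk = json.loads(line[len("data:"):].strip())
            print(chunk["choices"][0]["text"], end="", flush=True)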