diff --git a/OAI/types/common.py b/OAI/types/common.py
index 9c5ebbd..5cb2742 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -1,19 +1,10 @@
 """ Common types for OAI. """
 from pydantic import BaseModel, Field
-from typing import List, Dict, Optional
+from typing import Optional
 
 from common.sampling import BaseSamplerRequest
 
 
-class LogProbs(BaseModel):
-    """Represents log probabilities."""
-
-    text_offset: List[int] = Field(default_factory=list)
-    token_logprobs: List[Optional[float]] = Field(default_factory=list)
-    tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
-
-
 class UsageStats(BaseModel):
     """Represents usage stats."""
@@ -29,6 +20,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # This parameter is not used, the loaded model is used instead
     model: Optional[str] = None
 
+    # Generation info (remainder is in BaseSamplerRequest superclass)
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = 0
+
     # Extra OAI request stuff
     best_of: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
@@ -36,9 +31,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     echo: Optional[bool] = Field(
         description="Not parsed. Only used for OAI compliance.", default=False
     )
-    logprobs: Optional[int] = Field(
-        description="Not parsed. Only used for OAI compliance.", default=None
-    )
     n: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=1
     )
@@ -49,5 +41,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
         description="Not parsed. Only used for OAI compliance.", default=None
     )
 
-    # Generation info (remainder is in BaseSamplerRequest superclass)
-    stream: Optional[bool] = False
+    def to_gen_params(self):
+        extra_gen_params = {"logprobs": self.logprobs}
+
+        return super().to_gen_params(**extra_gen_params)
diff --git a/OAI/types/completion.py b/OAI/types/completion.py
index 4fa380c..4675ffc 100644
--- a/OAI/types/completion.py
+++ b/OAI/types/completion.py
@@ -1,10 +1,19 @@
 """ Completion API protocols """
 from pydantic import BaseModel, Field
 from time import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 from uuid import uuid4
 
-from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
+from OAI.types.common import CommonCompletionRequest, UsageStats
+
+
+class CompletionLogProbs(BaseModel):
+    """Represents log probabilities for a completion request."""
+
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
 
 
 class CompletionRespChoice(BaseModel):
@@ -13,7 +22,7 @@ class CompletionRespChoice(BaseModel):
 
     # Index is 0 since we aren't using multiple choices
     index: int = 0
     finish_reason: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     text: str
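Note: the relocated model serializes to the same `logprobs` object shape the legacy OpenAI completion endpoint returns. A standalone sketch with made-up values (not part of the diff):

    from OAI.types.completion import CompletionLogProbs

    # Hypothetical sample: two generated tokens with their top-2 alternatives
    sample = CompletionLogProbs(
        text_offset=[0, 5],
        token_logprobs=[-0.12, -1.87],
        tokens=["Hello", " world"],
        top_logprobs=[
            {"Hello": -0.12, "Hi": -2.30},
            {" world": -1.87, " there": -2.11},
        ],
    )
    print(sample.model_dump_json())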
diff --git a/OAI/utils/completion.py b/OAI/utils/completion.py
index 500451a..fddf666 100644
--- a/OAI/utils/completion.py
+++ b/OAI/utils/completion.py
@@ -9,22 +9,41 @@ from OAI.types.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamChoice,
 )
-from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.completion import (
+    CompletionResponse,
+    CompletionRespChoice,
+    CompletionLogProbs,
+)
 from OAI.types.common import UsageStats
 
 
-def create_completion_response(
-    text: str,
-    prompt_tokens: int,
-    completion_tokens: int,
-    model_name: Optional[str],
-):
+def create_completion_response(**kwargs):
     """Create a completion response from the provided text."""
-    choice = CompletionRespChoice(finish_reason="Generated", text=text)
+
+    token_probs = unwrap(kwargs.get("token_probs"), {})
+    logprobs = unwrap(kwargs.get("logprobs"), [])
+    offset = unwrap(kwargs.get("offset"), [])
+
+    logprob_response = CompletionLogProbs(
+        text_offset=offset if isinstance(offset, list) else [offset],
+        token_logprobs=list(token_probs.values()),
+        tokens=list(token_probs.keys()),
+        top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+    )
+
+    choice = CompletionRespChoice(
+        finish_reason="Generated",
+        text=unwrap(kwargs.get("text"), ""),
+        logprobs=logprob_response,
+    )
+
+    prompt_tokens = unwrap(kwargs.get("prompt_tokens"), 0)
+    # The backend reports the completion count as "generated_tokens"
+    completion_tokens = unwrap(kwargs.get("generated_tokens"), 0)
 
     response = CompletionResponse(
         choices=[choice],
-        model=unwrap(model_name, ""),
+        model=unwrap(kwargs.get("model_name"), ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -37,12 +56,12 @@ def create_completion_response(
 
 def create_chat_completion_response(
-    text: str,
-    prompt_tokens: int,
-    completion_tokens: int,
+    text: Optional[str],
+    prompt_tokens: Optional[int],
+    completion_tokens: Optional[int],
     model_name: Optional[str],
 ):
     """Create a chat completion response from the provided text."""
-    message = ChatCompletionMessage(role="assistant", content=text)
+    message = ChatCompletionMessage(role="assistant", content=unwrap(text, ""))
 
     choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
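Note: since create_completion_response now takes everything through **kwargs, a backend generation dict can be splatted straight into it. A rough illustration with fabricated values (the keys mirror what the exllamav2 backend below yields; "test-model" is a stand-in):

    from OAI.utils.completion import create_completion_response

    generation = {
        "text": "Hello world",
        "prompt_tokens": 4,
        "generated_tokens": 2,
        "offset": [0, 5],
        "token_probs": {"Hello": -0.12, " world": -1.87},
        "logprobs": [
            {"Hello": -0.12, "Hi": -2.30},
            {" world": -1.87, " there": -2.11},
        ],
    }
    response = create_completion_response(**generation, model_name="test-model")
    print(response.usage.completion_tokens)  # 2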
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 356e984..e86f4cb 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -472,20 +472,75 @@ class ExllamaV2Container:
             "unk_token": self.tokenizer.unk_token,
         }
 
+    def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
+        """Get the top tokens and log probabilities from a logit distribution."""
+        normalized_logits = torch.log_softmax(logits, dim=-1)
+        top_values, top_ids = torch.topk(normalized_logits, max_logprobs, dim=-1)
+
+        top_tokens = list(
+            map(
+                lambda index: self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                ),
+                top_ids[0].tolist(),
+            )
+        )
+        top_values = top_values[0].tolist()
+
+        return dict(zip(top_tokens, top_values, strict=True))
+
+    def get_token_probs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
+        """Map sampled token IDs to their log probabilities."""
+        tokens = list(
+            map(
+                lambda index: self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                ),
+                token_ids[0].tolist(),
+            )
+        )
+
+        return dict(zip(tokens, token_probs[0].tolist(), strict=True))
+
+    def generate(self, prompt: str, **kwargs):
+        """Generate a response to a prompt."""
+        generations = list(self.generate_gen(prompt, **kwargs))
+
+        joined_generation = {
+            "text": "",
+            "prompt_tokens": 0,
+            "generated_tokens": 0,
+            "offset": [],
+            "token_probs": {},
+            "logprobs": [],
+        }
+
+        if generations:
+            for generation in generations:
+                joined_generation["text"] += unwrap(generation.get("text"), "")
+                joined_generation["offset"].append(unwrap(generation.get("offset"), 0))
+                joined_generation["token_probs"].update(
+                    unwrap(generation.get("token_probs"), {})
+                )
+                joined_generation["logprobs"].append(
+                    unwrap(generation.get("logprobs"), {})
+                )
+
+            # Token counts accumulate, so the last chunk has the totals
+            joined_generation["prompt_tokens"] = unwrap(
+                generations[-1].get("prompt_tokens"), 0
+            )
+            joined_generation["generated_tokens"] = unwrap(
+                generations[-1].get("generated_tokens"), 0
+            )
+
+        return joined_generation
+
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
         pass
 
-    def generate(self, prompt: str, **kwargs):
-        """Generate a response to a prompt"""
-        generation = list(self.generate_gen(prompt, **kwargs))
-        if generation:
-            response = "".join(map(lambda chunk: chunk[0], generation))
-            return response, generation[-1][1], generation[-1][2]
-
-        return "", 0, 0
-
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
         """
@@ -639,6 +694,7 @@ class ExllamaV2Container:
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
+        request_logprobs = unwrap(kwargs.get("logprobs"), 0)
 
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -657,6 +713,7 @@ class ExllamaV2Container:
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
         )
@@ -758,7 +815,7 @@ class ExllamaV2Container:
                 gen_settings.token_repetition_range = generated_tokens
 
             # Generate
-            chunk, eos, tokens, _, _ = self.generator.stream()
+            chunk, eos, tokens, token_probs, logits = self.generator.stream()
 
             if token_healing:
                 # Extract healed token
@@ -780,7 +837,27 @@ class ExllamaV2Container:
             if chunk_buffer != "" and (
                 elapsed > stream_interval or eos or generated_tokens == max_tokens
             ):
-                yield chunk_buffer, prompt_tokens, generated_tokens
+                generation = {
+                    "text": chunk_buffer,
+                    "prompt_tokens": prompt_tokens,
+                    "generated_tokens": generated_tokens,
+                    "offset": len(full_response),
+                }
+
+                if request_logprobs > 0:
+                    # Get sampled token probs
+                    if token_probs.numel() > 0 and tokens.numel() > 0:
+                        generation["token_probs"] = self.get_token_probs(
+                            tokens, token_probs
+                        )
+
+                    # Get logprob choices
+                    if logits.numel() > 0:
+                        generation["logprobs"] = self.get_logprobs(
+                            logits, request_logprobs
+                        )
+
+                yield generation
                 full_response += chunk_buffer
                 chunk_buffer = ""
                 last_chunk_time = now
diff --git a/common/sampling.py b/common/sampling.py
index 5824ded..c818b83 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -159,7 +159,7 @@ class BaseSamplerRequest(BaseModel):
         examples=[1.0],
     )
 
-    def to_gen_params(self):
+    def to_gen_params(self, **kwargs):
        """Converts samplers to internal generation params"""
 
         # Add forced overrides if present
@@ -201,7 +201,7 @@ class BaseSamplerRequest(BaseModel):
             "negative_prompt": self.negative_prompt,
         }
 
-        return gen_params
+        return {**gen_params, **kwargs}
 
 
 # Global for default overrides
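Note: get_logprobs boils down to a normalized top-k over the raw logits; log_softmax turns logits into true log probabilities before the top candidates are picked. The same math in isolation, with a toy five-token vocabulary and fabricated logits:

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # batch dim of 1

    # Normalize, then take the top 3 candidates
    log_probs = torch.log_softmax(logits, dim=-1)
    top_values, top_ids = torch.topk(log_probs, 3, dim=-1)

    vocab = ["a", "b", "c", "d", "e"]  # stand-in for the tokenizer's id_to_piece
    top_tokens = [vocab[i] for i in top_ids[0].tolist()]
    print(dict(zip(top_tokens, top_values[0].tolist(), strict=True)))
    # -> {'a': -0.50, 'b': -1.50, 'c': -2.00} (approximately)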
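Note: the **kwargs passthrough added to to_gen_params above is what lets CommonCompletionRequest tack logprobs onto the sampler params without rebuilding the dict. A stripped-down sketch of the pattern (stub classes, not the real models):

    class Base:
        def to_gen_params(self, **kwargs):
            gen_params = {"temperature": 0.8, "top_p": 0.9}
            # Subclass-provided extras win on key collisions
            return {**gen_params, **kwargs}

    class Completion(Base):
        logprobs = 3

        def to_gen_params(self):
            return super().to_gen_params(logprobs=self.logprobs)

    print(Completion().to_gen_params())
    # {'temperature': 0.8, 'top_p': 0.9, 'logprobs': 3}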
diff --git a/main.py b/main.py
index 71a90da..e875a06 100644
--- a/main.py
+++ b/main.py
@@ -458,12 +458,13 @@ async def generate_completion(request: Request, data: CompletionRequest):
             new_generation = MODEL_CONTAINER.generate_gen(
                 data.prompt, **data.to_gen_params()
             )
-            for part, prompt_tokens, completion_tokens in new_generation:
+            for generation in new_generation:
                 if await request.is_disconnected():
                     break
 
                 response = create_completion_response(
-                    part, prompt_tokens, completion_tokens, model_path.name
+                    **generation,
+                    model_name=model_path.name,
                 )
 
                 yield get_sse_packet(response.model_dump_json())
@@ -479,13 +480,10 @@ async def generate_completion(request: Request, data: CompletionRequest):
             generate_with_semaphore(generator), media_type="text/event-stream"
         )
 
-    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+    generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
     )
-
-    response = create_completion_response(
-        response_text, prompt_tokens, completion_tokens, model_path.name
-    )
+    response = create_completion_response(**generation, model_name=model_path.name)
 
     return response
@@ -545,12 +543,12 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             new_generation = MODEL_CONTAINER.generate_gen(
                 prompt, **data.to_gen_params()
             )
-            for part, _, _ in new_generation:
+            for generation in new_generation:
                 if await request.is_disconnected():
                     break
 
                 response = create_chat_completion_stream_chunk(
-                    const_id, part, model_path.name
+                    const_id, generation.get("text"), model_path.name
                 )
 
                 yield get_sse_packet(response.model_dump_json())
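Note: end to end, a client only has to set logprobs on a completion request. A hypothetical call, assuming the server runs on its usual localhost port and exposes the OAI-style /v1/completions route:

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/completions",
        json={"prompt": "Once upon a time", "max_tokens": 8, "logprobs": 3},
    )
    print(resp.json()["choices"][0]["logprobs"])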