Model + API: Migrate to use BaseSamplerParams

kwargs is pretty ugly when figuring out which arguments to use. The
base request falls back to defaults anyway, so pass in the params
object as-is.
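
As a rough sketch of the idea (illustrative only, not the exact code in
this diff): instead of threading an untyped **kwargs dict through every
call and fishing values out with kwargs.get(), the validated request
object is passed straight through and read as attributes.

    from typing import Optional
    from pydantic import BaseModel, Field

    class BaseSamplerRequest(BaseModel):
        # Defaults live on the model, so callers never need kwargs.get()
        temperature: Optional[float] = Field(default=1.0, ge=0)
        top_k: Optional[int] = Field(default=0, ge=0)

    # Old style: every sampler key is pulled from an untyped dict
    def generate_old(prompt: str, **kwargs):
        temperature = kwargs.get("temperature")  # may silently be None

    # New style: the request object carries its own types and defaults
    def generate_new(request_id: str, prompt: str, params: BaseSamplerRequest):
        temperature = params.temperature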

However, since Python's typing can't transform types the way TypeScript
can, the type hints still have a possibility of None showing up despite
there always being a value for some params.
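
For example (a minimal sketch mirroring the logprobs field added below;
the get_default_sampler_value body here is an assumed stand-in): the
default_factory guarantees a value at runtime, but the annotation stays
Optional[int], so type checkers still see int | None at call sites like
params.logprobs > 0.

    from typing import Optional
    from pydantic import BaseModel, Field

    def get_default_sampler_value(key: str, fallback=None):
        # Assumed stand-in for the real default lookup helper
        return fallback

    class BaseSamplerRequest(BaseModel):
        logprobs: Optional[int] = Field(
            default_factory=lambda: get_default_sampler_value("logprobs", 0),
            ge=0,
        )

    params = BaseSamplerRequest()
    assert params.logprobs is not None  # always holds at runtime
    # A type checker still reports Optional[int] for params.logprobs.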

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri
2025-04-16 00:50:05 -04:00
parent dcb36e9ab2
commit 3084ef9fa1
5 changed files with 113 additions and 121 deletions

View File

@@ -1,17 +1,13 @@
"""The model container class for ExLlamaV2 models.""" """The model container class for ExLlamaV2 models."""
from functools import partial
import aiofiles import aiofiles
import asyncio import asyncio
import gc import gc
import math import math
import pathlib import pathlib
import traceback import traceback
from backends.exllamav2.vision import clear_image_embedding_cache
from common.multimodal import MultimodalEmbeddingWrapper
import torch import torch
import uuid import uuid
from copy import deepcopy
from exllamav2 import ( from exllamav2 import (
ExLlamaV2, ExLlamaV2,
ExLlamaV2Config, ExLlamaV2Config,
@@ -32,7 +28,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
 from ruamel.yaml import YAML
@@ -47,6 +43,7 @@ from backends.exllamav2.utils import (
     hardware_supports_flash_attn,
     supports_paged_attn,
 )
+from backends.exllamav2.vision import clear_image_embedding_cache
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -54,6 +51,8 @@ from common.gen_logging import (
     log_prompt,
     log_response,
 )
+from common.multimodal import MultimodalEmbeddingWrapper
+from common.sampling import BaseSamplerRequest
 from common.templating import (
     PromptTemplate,
     TemplateLoadError,
@@ -976,15 +975,20 @@ class ExllamaV2Container:
     async def generate(
         self,
-        prompt: str,
         request_id: str,
-        abort_event: asyncio.Event = None,
-        **kwargs,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
     ):
         """Generate a response to a prompt."""

         generations = []
         async for generation in self.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
         ):
             generations.append(generation)
@@ -1031,21 +1035,22 @@ class ExllamaV2Container:
         return joined_generation

-    def check_unsupported_settings(self, **kwargs):
+    def check_unsupported_settings(self, params: BaseSamplerRequest):
         """
         Check and warn the user if a sampler is unsupported.
         Meant for dev wheels!
         """

-        return kwargs
+        return params

     async def generate_gen(
         self,
-        prompt: str,
         request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
         abort_event: Optional[asyncio.Event] = None,
-        **kwargs,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
     ):
         """
         Create generator function for prompt completion.
@@ -1059,46 +1064,43 @@ class ExllamaV2Container:
         prompts = [prompt]

-        token_healing = kwargs.get("token_healing")
-        generate_window = max(
-            kwargs.get("generate_window"), self.config.max_seq_len // 8
-        )
+        # TODO: Not used for some reason?
+        generate_window = max(params.generate_window, self.config.max_seq_len // 8)

         # Sampler settings
         gen_settings = ExLlamaV2Sampler.Settings()

         # Check unsupported settings for dev wheels
-        kwargs = self.check_unsupported_settings(**kwargs)
+        params = self.check_unsupported_settings(params)

         # Apply settings
-        gen_settings.temperature = kwargs.get("temperature")
-        gen_settings.temperature_last = kwargs.get("temperature_last")
-        gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
-        gen_settings.top_k = kwargs.get("top_k")
-        gen_settings.top_p = kwargs.get("top_p")
-        gen_settings.top_a = kwargs.get("top_a")
-        gen_settings.min_p = kwargs.get("min_p")
-        gen_settings.tfs = kwargs.get("tfs")
-        gen_settings.typical = kwargs.get("typical")
-        gen_settings.mirostat = kwargs.get("mirostat")
-        gen_settings.skew = kwargs.get("skew")
+        gen_settings.temperature = params.temperature
+        gen_settings.temperature_last = params.temperature_last
+        gen_settings.smoothing_factor = params.smoothing_factor
+        gen_settings.top_k = params.top_k
+        gen_settings.top_p = params.top_p
+        gen_settings.top_a = params.top_a
+        gen_settings.min_p = params.min_p
+        gen_settings.tfs = params.tfs
+        gen_settings.typical = params.typical
+        gen_settings.mirostat = params.mirostat
+        gen_settings.skew = params.skew

         # XTC
-        xtc_probability = kwargs.get("xtc_probability")
-        if xtc_probability > 0.0:
-            gen_settings.xtc_probability = xtc_probability
+        if params.xtc_probability > 0.0:
+            gen_settings.xtc_probability = params.xtc_probability

             # 0.1 is the default for this value
-            gen_settings.xtc_threshold = kwargs.get("xtc_threshold")
+            gen_settings.xtc_threshold = params.xtc_threshold

         # DynaTemp settings
-        max_temp = kwargs.get("max_temp")
-        min_temp = kwargs.get("min_temp")
+        max_temp = params.max_temp
+        min_temp = params.min_temp

-        if max_temp > min_temp:
+        if params.max_temp > params.min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = kwargs.get("temp_exponent")
+            gen_settings.temp_exponent = params.temp_exponent
         else:
             # Force to default values
             gen_settings.max_temp = 1.0
@@ -1115,11 +1117,11 @@ class ExllamaV2Container:
             )

         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
-        gen_settings.mirostat_eta = kwargs.get("mirostat_eta")
+        gen_settings.mirostat_tau = params.mirostat_tau
+        gen_settings.mirostat_eta = params.mirostat_eta

         # Set CFG scale and negative prompt
-        cfg_scale = kwargs.get("cfg_scale")
+        cfg_scale = params.cfg_scale
         negative_prompt = None
         if cfg_scale not in [None, 1.0]:
             if self.paged:
@@ -1127,7 +1129,7 @@ class ExllamaV2Container:
                 # If the negative prompt is empty, use the BOS token
                 negative_prompt = unwrap(
-                    kwargs.get("negative_prompt"), self.tokenizer.bos_token
+                    params.negative_prompt, self.tokenizer.bos_token
                 )

                 prompts.append(negative_prompt)
@@ -1138,15 +1140,16 @@ class ExllamaV2Container:
                 )

         # Penalties
-        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
-        gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
-        gen_settings.token_presence_penalty = kwargs.get("presence_penalty")
+        gen_settings.token_repetition_penalty = params.repetition_penalty
+        gen_settings.token_frequency_penalty = params.frequency_penalty
+        gen_settings.token_presence_penalty = params.presence_penalty

         # Applies for all penalties despite being called token_repetition_range
         gen_settings.token_repetition_range = unwrap(
-            kwargs.get("penalty_range"), self.config.max_seq_len
+            params.penalty_range, self.config.max_seq_len
         )

+        # TODO: Not used for some reason?
         # Dynamically scale penalty range to output tokens
         # Only do this if freq/pres pen is enabled
         # and the repetition range is -1
@@ -1164,54 +1167,51 @@ class ExllamaV2Container:
         else:
             fallback_decay = gen_settings.token_repetition_range
         gen_settings.token_repetition_decay = coalesce(
-            kwargs.get("repetition_decay"), fallback_decay, 0
+            params.repetition_decay, fallback_decay, 0
         )

         # DRY options
-        dry_multiplier = kwargs.get("dry_multiplier")
+        dry_multiplier = params.dry_multiplier

         # < 0 = disabled
         if dry_multiplier > 0:
             gen_settings.dry_multiplier = dry_multiplier
-
-            gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
-            gen_settings.dry_base = kwargs.get("dry_base")
+            gen_settings.dry_allowed_length = params.dry_allowed_length
+            gen_settings.dry_base = params.dry_base

             # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
             # Use max_seq_len as the fallback to stay consistent
-            gen_settings.dry_range = unwrap(
-                kwargs.get("dry_range"), self.config.max_seq_len
-            )
+            gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)

             # Tokenize sequence breakers
-            dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
-            if dry_sequence_breakers_json:
+            if params.dry_sequence_breakers:
                 gen_settings.dry_sequence_breakers = {
-                    self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
+                    self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
                 }

         # Initialize grammar handler
         grammar_handler = ExLlamaV2Grammar()

         # Add JSON schema filter if it exists
-        json_schema = kwargs.get("json_schema")
-        if json_schema:
+        if params.json_schema:
             grammar_handler.add_json_schema_filter(
-                json_schema, self.model, self.tokenizer
+                params.json_schema, self.model, self.tokenizer
             )

         # Add regex filter if it exists
-        regex_pattern = kwargs.get("regex_pattern")
-        if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
+        if params.regex_pattern:
+            grammar_handler.add_regex_filter(
+                params.regex_pattern, self.model, self.tokenizer
+            )

         # Add EBNF filter if it exists
-        grammar_string = kwargs.get("grammar_string")
-        if grammar_string:
-            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
+        if params.grammar_string:
+            grammar_handler.add_kbnf_filter(
+                params.grammar_string, self.model, self.tokenizer
+            )

         # Set banned strings
-        banned_strings = kwargs.get("banned_strings")
+        banned_strings = params.banned_strings
         if banned_strings and len(grammar_handler.filters) > 0:
             logger.warning(
                 "Disabling banned_strings because "
@@ -1220,16 +1220,12 @@ class ExllamaV2Container:
             banned_strings = []

-        stop_conditions = kwargs.get("stop")
-        add_bos_token = kwargs.get("add_bos_token"), True
-        ban_eos_token = kwargs.get("ban_eos_token"), False
-        logit_bias = kwargs.get("logit_bias")
-
-        # Logprobs
-        request_logprobs = kwargs.get("logprobs")
+        stop_conditions = params.stop
+        add_bos_token = params.add_bos_token
+        ban_eos_token = params.ban_eos_token

         # Speculative Ngram
-        self.generator.speculative_ngram = kwargs.get("speculative_ngram")
+        self.generator.speculative_ngram = params.speculative_ngram

         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -1244,17 +1240,15 @@ class ExllamaV2Container:
             )

         # Set banned tokens
-        banned_tokens = kwargs.get("banned_tokens")
-        if banned_tokens:
-            gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
+        if params.banned_tokens:
+            gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)

         # Set allowed tokens
-        allowed_tokens = kwargs.get("allowed_tokens")
-        if allowed_tokens:
-            gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
+        if params.allowed_tokens:
+            gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)

         # Set logit bias
-        if logit_bias:
+        if params.logit_bias:
             # Create a vocab tensor if it doesn't exist for token biasing
             if gen_settings.token_bias is None:
                 padding = -self.tokenizer.config.vocab_size % 32
@@ -1264,7 +1258,7 @@ class ExllamaV2Container:
             )

             # Map logits to the tensor with their biases
-            for token_id, bias in logit_bias.items():
+            for token_id, bias in params.logit_bias.items():
                 if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                     gen_settings.token_bias[token_id] = bias
                 else:
@@ -1289,7 +1283,7 @@ class ExllamaV2Container:
             stop_conditions += eos_tokens

         # Get multimodal embeddings if present
-        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        # TODO: Remove kwargs and pass this as optional
         mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

         # Encode both positive and negative prompts
@@ -1312,7 +1306,7 @@ class ExllamaV2Container:
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"),
+            params.max_tokens,
             self.config.max_seq_len - max(context_len, negative_context_len),
         )
         if max_tokens < 1:
@@ -1349,12 +1343,6 @@ class ExllamaV2Container:
f"is greater than cache_size {self.cache_size}" f"is greater than cache_size {self.cache_size}"
) )
# Set min_tokens to generate while keeping EOS banned
min_tokens = kwargs.get("min_tokens")
# This is an inverse of skip_special_tokens
decode_special_tokens = not kwargs.get("skip_special_tokens")
# Log prompt to console. Add the BOS token if specified # Log prompt to console. Add the BOS token if specified
log_prompt( log_prompt(
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}", f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
@@ -1369,17 +1357,17 @@ class ExllamaV2Container:
             self.generator,
             input_ids=input_ids,
             max_new_tokens=max_tokens,
-            min_new_tokens=min_tokens,
+            min_new_tokens=params.min_tokens,
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
-            decode_special_tokens=decode_special_tokens,
+            decode_special_tokens=not params.skip_special_tokens,
             filters=grammar_handler.filters,
             filter_prefer_eos=bool(grammar_handler.filters),
-            return_probs=request_logprobs > 0,
-            return_top_tokens=request_logprobs,
-            return_logits=request_logprobs > 0,
+            return_probs=params.logprobs > 0,
+            return_top_tokens=params.logprobs,
+            return_logits=params.logprobs > 0,
             banned_strings=banned_strings,
-            token_healing=token_healing,
+            token_healing=params.token_healing,
             identifier=job_id,
             embeddings=mm_embeddings_content,
         )
@@ -1418,7 +1406,7 @@ class ExllamaV2Container:
"offset": len(full_response), "offset": len(full_response),
} }
if request_logprobs > 0: if params.logprobs > 0:
# Get top tokens and probs # Get top tokens and probs
top_tokens = unwrap( top_tokens = unwrap(
result.get("top_k_tokens"), result.get("top_k_tokens"),
@@ -1494,8 +1482,7 @@ class ExllamaV2Container:
             request_id=request_id,
             bos_token_id=self.tokenizer.bos_token_id,
             eos_token_id=eos_tokens,
-            **kwargs,
-            generate_window=generate_window,
+            **params.model_dump(),
             auto_scale_penalty_range=auto_scale_penalty_range,
         )

View File

@@ -282,6 +282,11 @@ class BaseSamplerRequest(BaseModel):
         ge=0,
     )

+    logprobs: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("logprobs", 0),
+        ge=0,
+    )
+
     @field_validator("top_k", mode="before")
     def convert_top_k(cls, v):
         """Fixes instance if Top-K is -1."""

View File

@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
     stream_options: Optional[ChatCompletionStreamOptions] = None
-    logprobs: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("logprobs", 0),
-        ge=0,
-    )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )

View File

@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    prompt,
                     request.state.id,
+                    prompt,
+                    task_gen_params,
                     abort_event,
-                    embeddings=embeddings,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    mm_embeddings=embeddings,
                 )
             )
@@ -422,10 +422,10 @@ async def generate_chat_completion(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        prompt,
                         request.state.id,
-                        embeddings=embeddings,
-                        **data.model_dump(exclude={"prompt"}),
+                        prompt,
+                        data,
+                        mm_embeddings=embeddings,
                     )
                 )
             )
@@ -465,7 +465,6 @@ async def generate_tool_calls(
     # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
     tool_data.json_schema = tool_data.tool_call_schema
-    gen_params = tool_data.model_dump()

     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
@@ -488,10 +487,10 @@ async def generate_tool_calls(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        pre_tool_prompt,
                         request.state.id,
+                        pre_tool_prompt,
+                        tool_data,
                         embeddings=mm_embeddings,
-                        **gen_params,
                     )
                 )
             )

View File

@@ -8,12 +8,12 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import List, Union
 from loguru import logger
+from typing import List, Optional, Union

 from common import model
 from common.auth import get_key_permission
+from common.multimodal import MultimodalEmbeddingWrapper
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
 async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
-    prompt: str,
     request_id: str,
+    prompt: str,
+    params: CompletionRequest,
     abort_event: asyncio.Event,
-    **kwargs,
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
 ):
     """Collects a stream and places results in a common queue"""

     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
         )
         async for generation in new_generation:
             generation["index"] = task_idx
@@ -195,10 +200,10 @@ async def stream_generate_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    data.prompt,
                     request.state.id,
+                    data.prompt,
+                    task_gen_params,
                     abort_event,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )
@@ -256,9 +261,9 @@ async def generate_completion(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        data.prompt,
                         request.state.id,
-                        **task_gen_params.model_dump(exclude={"prompt"}),
+                        data.prompt,
+                        task_gen_params,
                     )
                 )
             )