mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-30 11:11:35 +00:00
Model + API: Migrate to use BaseSamplerRequest
Using kwargs is pretty ugly when figuring out which arguments to use. The base request falls back to defaults anyway, so pass in the params object as is. However, since Python's typing isn't like TypeScript's, where types can be transformed, the type hints have a possibility of None showing up even though some params always have a value. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -1,17 +1,13 @@
|
|||||||
"""The model container class for ExLlamaV2 models."""
|
"""The model container class for ExLlamaV2 models."""
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import pathlib
|
import pathlib
|
||||||
import traceback
|
import traceback
|
||||||
from backends.exllamav2.vision import clear_image_embedding_cache
|
|
||||||
from common.multimodal import MultimodalEmbeddingWrapper
|
|
||||||
import torch
|
import torch
|
||||||
import uuid
|
import uuid
|
||||||
from copy import deepcopy
|
|
||||||
from exllamav2 import (
|
from exllamav2 import (
|
||||||
ExLlamaV2,
|
ExLlamaV2,
|
||||||
ExLlamaV2Config,
|
ExLlamaV2Config,
|
||||||
@@ -32,7 +28,7 @@ from exllamav2.generator import (
|
|||||||
)
|
)
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
@@ -47,6 +43,7 @@ from backends.exllamav2.utils import (
|
|||||||
hardware_supports_flash_attn,
|
hardware_supports_flash_attn,
|
||||||
supports_paged_attn,
|
supports_paged_attn,
|
||||||
)
|
)
|
||||||
|
from backends.exllamav2.vision import clear_image_embedding_cache
|
||||||
from common.concurrency import iterate_in_threadpool
|
from common.concurrency import iterate_in_threadpool
|
||||||
from common.gen_logging import (
|
from common.gen_logging import (
|
||||||
log_generation_params,
|
log_generation_params,
|
||||||
@@ -54,6 +51,8 @@ from common.gen_logging import (
|
|||||||
log_prompt,
|
log_prompt,
|
||||||
log_response,
|
log_response,
|
||||||
)
|
)
|
||||||
|
from common.multimodal import MultimodalEmbeddingWrapper
|
||||||
|
from common.sampling import BaseSamplerRequest
|
||||||
from common.templating import (
|
from common.templating import (
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
TemplateLoadError,
|
TemplateLoadError,
|
||||||
@@ -976,15 +975,20 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
|
||||||
request_id: str,
|
request_id: str,
|
||||||
abort_event: asyncio.Event = None,
|
prompt: str,
|
||||||
**kwargs,
|
params: BaseSamplerRequest,
|
||||||
|
abort_event: Optional[asyncio.Event] = None,
|
||||||
|
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||||
):
|
):
|
||||||
"""Generate a response to a prompt."""
|
"""Generate a response to a prompt."""
|
||||||
generations = []
|
generations = []
|
||||||
async for generation in self.generate_gen(
|
async for generation in self.generate_gen(
|
||||||
prompt, request_id, abort_event, **kwargs
|
request_id,
|
||||||
|
prompt,
|
||||||
|
params,
|
||||||
|
abort_event,
|
||||||
|
mm_embeddings,
|
||||||
):
|
):
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
@@ -1031,21 +1035,22 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
return joined_generation
|
return joined_generation
|
||||||
|
|
||||||
def check_unsupported_settings(self, **kwargs):
|
def check_unsupported_settings(self, params: BaseSamplerRequest):
|
||||||
"""
|
"""
|
||||||
Check and warn the user if a sampler is unsupported.
|
Check and warn the user if a sampler is unsupported.
|
||||||
|
|
||||||
Meant for dev wheels!
|
Meant for dev wheels!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return kwargs
|
return params
|
||||||
|
|
||||||
async def generate_gen(
|
async def generate_gen(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
prompt: str,
|
||||||
|
params: BaseSamplerRequest,
|
||||||
abort_event: Optional[asyncio.Event] = None,
|
abort_event: Optional[asyncio.Event] = None,
|
||||||
**kwargs,
|
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create generator function for prompt completion.
|
Create generator function for prompt completion.
|
||||||
@@ -1059,46 +1064,43 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
prompts = [prompt]
|
prompts = [prompt]
|
||||||
|
|
||||||
token_healing = kwargs.get("token_healing")
|
# TODO: Not used for some reason?
|
||||||
generate_window = max(
|
generate_window = max(params.generate_window, self.config.max_seq_len // 8)
|
||||||
kwargs.get("generate_window"), self.config.max_seq_len // 8
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sampler settings
|
# Sampler settings
|
||||||
gen_settings = ExLlamaV2Sampler.Settings()
|
gen_settings = ExLlamaV2Sampler.Settings()
|
||||||
|
|
||||||
# Check unsupported settings for dev wheels
|
# Check unsupported settings for dev wheels
|
||||||
kwargs = self.check_unsupported_settings(**kwargs)
|
params = self.check_unsupported_settings(params)
|
||||||
|
|
||||||
# Apply settings
|
# Apply settings
|
||||||
gen_settings.temperature = kwargs.get("temperature")
|
gen_settings.temperature = params.temperature
|
||||||
gen_settings.temperature_last = kwargs.get("temperature_last")
|
gen_settings.temperature_last = params.temperature_last
|
||||||
gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
|
gen_settings.smoothing_factor = params.smoothing_factor
|
||||||
gen_settings.top_k = kwargs.get("top_k")
|
gen_settings.top_k = params.top_k
|
||||||
gen_settings.top_p = kwargs.get("top_p")
|
gen_settings.top_p = params.top_p
|
||||||
gen_settings.top_a = kwargs.get("top_a")
|
gen_settings.top_a = params.top_a
|
||||||
gen_settings.min_p = kwargs.get("min_p")
|
gen_settings.min_p = params.min_p
|
||||||
gen_settings.tfs = kwargs.get("tfs")
|
gen_settings.tfs = params.tfs
|
||||||
gen_settings.typical = kwargs.get("typical")
|
gen_settings.typical = params.typical
|
||||||
gen_settings.mirostat = kwargs.get("mirostat")
|
gen_settings.mirostat = params.mirostat
|
||||||
gen_settings.skew = kwargs.get("skew")
|
gen_settings.skew = params.skew
|
||||||
|
|
||||||
# XTC
|
# XTC
|
||||||
xtc_probability = kwargs.get("xtc_probability")
|
if params.xtc_probability > 0.0:
|
||||||
if xtc_probability > 0.0:
|
gen_settings.xtc_probability = params.xtc_probability
|
||||||
gen_settings.xtc_probability = xtc_probability
|
|
||||||
|
|
||||||
# 0.1 is the default for this value
|
# 0.1 is the default for this value
|
||||||
gen_settings.xtc_threshold = kwargs.get("xtc_threshold")
|
gen_settings.xtc_threshold = params.xtc_threshold
|
||||||
|
|
||||||
# DynaTemp settings
|
# DynaTemp settings
|
||||||
max_temp = kwargs.get("max_temp")
|
max_temp = params.max_temp
|
||||||
min_temp = kwargs.get("min_temp")
|
min_temp = params.min_temp
|
||||||
|
|
||||||
if max_temp > min_temp:
|
if params.max_temp > params.min_temp:
|
||||||
gen_settings.max_temp = max_temp
|
gen_settings.max_temp = max_temp
|
||||||
gen_settings.min_temp = min_temp
|
gen_settings.min_temp = min_temp
|
||||||
gen_settings.temp_exponent = kwargs.get("temp_exponent")
|
gen_settings.temp_exponent = params.temp_exponent
|
||||||
else:
|
else:
|
||||||
# Force to default values
|
# Force to default values
|
||||||
gen_settings.max_temp = 1.0
|
gen_settings.max_temp = 1.0
|
||||||
@@ -1115,11 +1117,11 @@ class ExllamaV2Container:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||||
gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
|
gen_settings.mirostat_tau = params.mirostat_tau
|
||||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta")
|
gen_settings.mirostat_eta = params.mirostat_eta
|
||||||
|
|
||||||
# Set CFG scale and negative prompt
|
# Set CFG scale and negative prompt
|
||||||
cfg_scale = kwargs.get("cfg_scale")
|
cfg_scale = params.cfg_scale
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
if cfg_scale not in [None, 1.0]:
|
if cfg_scale not in [None, 1.0]:
|
||||||
if self.paged:
|
if self.paged:
|
||||||
@@ -1127,7 +1129,7 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
# If the negative prompt is empty, use the BOS token
|
# If the negative prompt is empty, use the BOS token
|
||||||
negative_prompt = unwrap(
|
negative_prompt = unwrap(
|
||||||
kwargs.get("negative_prompt"), self.tokenizer.bos_token
|
params.negative_prompt, self.tokenizer.bos_token
|
||||||
)
|
)
|
||||||
|
|
||||||
prompts.append(negative_prompt)
|
prompts.append(negative_prompt)
|
||||||
@@ -1138,15 +1140,16 @@ class ExllamaV2Container:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Penalties
|
# Penalties
|
||||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
|
gen_settings.token_repetition_penalty = params.repetition_penalty
|
||||||
gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
|
gen_settings.token_frequency_penalty = params.frequency_penalty
|
||||||
gen_settings.token_presence_penalty = kwargs.get("presence_penalty")
|
gen_settings.token_presence_penalty = params.presence_penalty
|
||||||
|
|
||||||
# Applies for all penalties despite being called token_repetition_range
|
# Applies for all penalties despite being called token_repetition_range
|
||||||
gen_settings.token_repetition_range = unwrap(
|
gen_settings.token_repetition_range = unwrap(
|
||||||
kwargs.get("penalty_range"), self.config.max_seq_len
|
params.penalty_range, self.config.max_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Not used for some reason?
|
||||||
# Dynamically scale penalty range to output tokens
|
# Dynamically scale penalty range to output tokens
|
||||||
# Only do this if freq/pres pen is enabled
|
# Only do this if freq/pres pen is enabled
|
||||||
# and the repetition range is -1
|
# and the repetition range is -1
|
||||||
@@ -1164,54 +1167,51 @@ class ExllamaV2Container:
|
|||||||
else:
|
else:
|
||||||
fallback_decay = gen_settings.token_repetition_range
|
fallback_decay = gen_settings.token_repetition_range
|
||||||
gen_settings.token_repetition_decay = coalesce(
|
gen_settings.token_repetition_decay = coalesce(
|
||||||
kwargs.get("repetition_decay"), fallback_decay, 0
|
params.repetition_decay, fallback_decay, 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# DRY options
|
# DRY options
|
||||||
dry_multiplier = kwargs.get("dry_multiplier")
|
dry_multiplier = params.dry_multiplier
|
||||||
|
|
||||||
# < 0 = disabled
|
# < 0 = disabled
|
||||||
if dry_multiplier > 0:
|
if dry_multiplier > 0:
|
||||||
gen_settings.dry_multiplier = dry_multiplier
|
gen_settings.dry_multiplier = dry_multiplier
|
||||||
|
gen_settings.dry_allowed_length = params.dry_allowed_length
|
||||||
gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
|
gen_settings.dry_base = params.dry_base
|
||||||
gen_settings.dry_base = kwargs.get("dry_base")
|
|
||||||
|
|
||||||
# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
|
# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
|
||||||
# Use max_seq_len as the fallback to stay consistent
|
# Use max_seq_len as the fallback to stay consistent
|
||||||
gen_settings.dry_range = unwrap(
|
gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)
|
||||||
kwargs.get("dry_range"), self.config.max_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
# Tokenize sequence breakers
|
# Tokenize sequence breakers
|
||||||
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
|
if params.dry_sequence_breakers:
|
||||||
if dry_sequence_breakers_json:
|
|
||||||
gen_settings.dry_sequence_breakers = {
|
gen_settings.dry_sequence_breakers = {
|
||||||
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
|
self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize grammar handler
|
# Initialize grammar handler
|
||||||
grammar_handler = ExLlamaV2Grammar()
|
grammar_handler = ExLlamaV2Grammar()
|
||||||
|
|
||||||
# Add JSON schema filter if it exists
|
# Add JSON schema filter if it exists
|
||||||
json_schema = kwargs.get("json_schema")
|
if params.json_schema:
|
||||||
if json_schema:
|
|
||||||
grammar_handler.add_json_schema_filter(
|
grammar_handler.add_json_schema_filter(
|
||||||
json_schema, self.model, self.tokenizer
|
params.json_schema, self.model, self.tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add regex filter if it exists
|
# Add regex filter if it exists
|
||||||
regex_pattern = kwargs.get("regex_pattern")
|
if params.regex_pattern:
|
||||||
if regex_pattern:
|
grammar_handler.add_regex_filter(
|
||||||
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
|
params.regex_pattern, self.model, self.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
# Add EBNF filter if it exists
|
# Add EBNF filter if it exists
|
||||||
grammar_string = kwargs.get("grammar_string")
|
if params.grammar_string:
|
||||||
if grammar_string:
|
grammar_handler.add_kbnf_filter(
|
||||||
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
|
params.grammar_string, self.model, self.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
# Set banned strings
|
# Set banned strings
|
||||||
banned_strings = kwargs.get("banned_strings")
|
banned_strings = params.banned_strings
|
||||||
if banned_strings and len(grammar_handler.filters) > 0:
|
if banned_strings and len(grammar_handler.filters) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Disabling banned_strings because "
|
"Disabling banned_strings because "
|
||||||
@@ -1220,16 +1220,12 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
banned_strings = []
|
banned_strings = []
|
||||||
|
|
||||||
stop_conditions = kwargs.get("stop")
|
stop_conditions = params.stop
|
||||||
add_bos_token = kwargs.get("add_bos_token"), True
|
add_bos_token = params.add_bos_token
|
||||||
ban_eos_token = kwargs.get("ban_eos_token"), False
|
ban_eos_token = params.ban_eos_token
|
||||||
logit_bias = kwargs.get("logit_bias")
|
|
||||||
|
|
||||||
# Logprobs
|
|
||||||
request_logprobs = kwargs.get("logprobs")
|
|
||||||
|
|
||||||
# Speculative Ngram
|
# Speculative Ngram
|
||||||
self.generator.speculative_ngram = kwargs.get("speculative_ngram")
|
self.generator.speculative_ngram = params.speculative_ngram
|
||||||
|
|
||||||
# Override sampler settings for temp = 0
|
# Override sampler settings for temp = 0
|
||||||
if gen_settings.temperature == 0:
|
if gen_settings.temperature == 0:
|
||||||
@@ -1244,17 +1240,15 @@ class ExllamaV2Container:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set banned tokens
|
# Set banned tokens
|
||||||
banned_tokens = kwargs.get("banned_tokens")
|
if params.banned_tokens:
|
||||||
if banned_tokens:
|
gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)
|
||||||
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
|
|
||||||
|
|
||||||
# Set allowed tokens
|
# Set allowed tokens
|
||||||
allowed_tokens = kwargs.get("allowed_tokens")
|
if params.allowed_tokens:
|
||||||
if allowed_tokens:
|
gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)
|
||||||
gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
|
|
||||||
|
|
||||||
# Set logit bias
|
# Set logit bias
|
||||||
if logit_bias:
|
if params.logit_bias:
|
||||||
# Create a vocab tensor if it doesn't exist for token biasing
|
# Create a vocab tensor if it doesn't exist for token biasing
|
||||||
if gen_settings.token_bias is None:
|
if gen_settings.token_bias is None:
|
||||||
padding = -self.tokenizer.config.vocab_size % 32
|
padding = -self.tokenizer.config.vocab_size % 32
|
||||||
@@ -1264,7 +1258,7 @@ class ExllamaV2Container:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Map logits to the tensor with their biases
|
# Map logits to the tensor with their biases
|
||||||
for token_id, bias in logit_bias.items():
|
for token_id, bias in params.logit_bias.items():
|
||||||
if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
|
if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
|
||||||
gen_settings.token_bias[token_id] = bias
|
gen_settings.token_bias[token_id] = bias
|
||||||
else:
|
else:
|
||||||
@@ -1289,7 +1283,7 @@ class ExllamaV2Container:
|
|||||||
stop_conditions += eos_tokens
|
stop_conditions += eos_tokens
|
||||||
|
|
||||||
# Get multimodal embeddings if present
|
# Get multimodal embeddings if present
|
||||||
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
|
# TODO: Remove kwargs and pass this as optional
|
||||||
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
|
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
|
||||||
|
|
||||||
# Encode both positive and negative prompts
|
# Encode both positive and negative prompts
|
||||||
@@ -1312,7 +1306,7 @@ class ExllamaV2Container:
|
|||||||
# Automatically set max_tokens to fill up the context
|
# Automatically set max_tokens to fill up the context
|
||||||
# This should be an OK default, but may be changed in the future
|
# This should be an OK default, but may be changed in the future
|
||||||
max_tokens = unwrap(
|
max_tokens = unwrap(
|
||||||
kwargs.get("max_tokens"),
|
params.max_tokens,
|
||||||
self.config.max_seq_len - max(context_len, negative_context_len),
|
self.config.max_seq_len - max(context_len, negative_context_len),
|
||||||
)
|
)
|
||||||
if max_tokens < 1:
|
if max_tokens < 1:
|
||||||
@@ -1349,12 +1343,6 @@ class ExllamaV2Container:
|
|||||||
f"is greater than cache_size {self.cache_size}"
|
f"is greater than cache_size {self.cache_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set min_tokens to generate while keeping EOS banned
|
|
||||||
min_tokens = kwargs.get("min_tokens")
|
|
||||||
|
|
||||||
# This is an inverse of skip_special_tokens
|
|
||||||
decode_special_tokens = not kwargs.get("skip_special_tokens")
|
|
||||||
|
|
||||||
# Log prompt to console. Add the BOS token if specified
|
# Log prompt to console. Add the BOS token if specified
|
||||||
log_prompt(
|
log_prompt(
|
||||||
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
|
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
|
||||||
@@ -1369,17 +1357,17 @@ class ExllamaV2Container:
|
|||||||
self.generator,
|
self.generator,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
min_new_tokens=min_tokens,
|
min_new_tokens=params.min_tokens,
|
||||||
gen_settings=gen_settings,
|
gen_settings=gen_settings,
|
||||||
stop_conditions=stop_conditions,
|
stop_conditions=stop_conditions,
|
||||||
decode_special_tokens=decode_special_tokens,
|
decode_special_tokens=not params.skip_special_tokens,
|
||||||
filters=grammar_handler.filters,
|
filters=grammar_handler.filters,
|
||||||
filter_prefer_eos=bool(grammar_handler.filters),
|
filter_prefer_eos=bool(grammar_handler.filters),
|
||||||
return_probs=request_logprobs > 0,
|
return_probs=params.logprobs > 0,
|
||||||
return_top_tokens=request_logprobs,
|
return_top_tokens=params.logprobs,
|
||||||
return_logits=request_logprobs > 0,
|
return_logits=params.logprobs > 0,
|
||||||
banned_strings=banned_strings,
|
banned_strings=banned_strings,
|
||||||
token_healing=token_healing,
|
token_healing=params.token_healing,
|
||||||
identifier=job_id,
|
identifier=job_id,
|
||||||
embeddings=mm_embeddings_content,
|
embeddings=mm_embeddings_content,
|
||||||
)
|
)
|
||||||
@@ -1418,7 +1406,7 @@ class ExllamaV2Container:
|
|||||||
"offset": len(full_response),
|
"offset": len(full_response),
|
||||||
}
|
}
|
||||||
|
|
||||||
if request_logprobs > 0:
|
if params.logprobs > 0:
|
||||||
# Get top tokens and probs
|
# Get top tokens and probs
|
||||||
top_tokens = unwrap(
|
top_tokens = unwrap(
|
||||||
result.get("top_k_tokens"),
|
result.get("top_k_tokens"),
|
||||||
@@ -1494,8 +1482,7 @@ class ExllamaV2Container:
|
|||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
bos_token_id=self.tokenizer.bos_token_id,
|
bos_token_id=self.tokenizer.bos_token_id,
|
||||||
eos_token_id=eos_tokens,
|
eos_token_id=eos_tokens,
|
||||||
**kwargs,
|
**params.model_dump(),
|
||||||
generate_window=generate_window,
|
|
||||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -282,6 +282,11 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
ge=0,
|
ge=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logprobs: Optional[int] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("logprobs", 0),
|
||||||
|
ge=0,
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("top_k", mode="before")
|
@field_validator("top_k", mode="before")
|
||||||
def convert_top_k(cls, v):
|
def convert_top_k(cls, v):
|
||||||
"""Fixes instance if Top-K is -1."""
|
"""Fixes instance if Top-K is -1."""
|
||||||
|
|||||||
@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
|||||||
# Generation info (remainder is in BaseSamplerRequest superclass)
|
# Generation info (remainder is in BaseSamplerRequest superclass)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
stream_options: Optional[ChatCompletionStreamOptions] = None
|
stream_options: Optional[ChatCompletionStreamOptions] = None
|
||||||
logprobs: Optional[int] = Field(
|
|
||||||
default_factory=lambda: get_default_sampler_value("logprobs", 0),
|
|
||||||
ge=0,
|
|
||||||
)
|
|
||||||
response_format: Optional[CompletionResponseFormat] = Field(
|
response_format: Optional[CompletionResponseFormat] = Field(
|
||||||
default_factory=CompletionResponseFormat
|
default_factory=CompletionResponseFormat
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
|
|||||||
_stream_collector(
|
_stream_collector(
|
||||||
n,
|
n,
|
||||||
gen_queue,
|
gen_queue,
|
||||||
prompt,
|
|
||||||
request.state.id,
|
request.state.id,
|
||||||
|
prompt,
|
||||||
|
task_gen_params,
|
||||||
abort_event,
|
abort_event,
|
||||||
embeddings=embeddings,
|
mm_embeddings=embeddings,
|
||||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -422,10 +422,10 @@ async def generate_chat_completion(
|
|||||||
gen_tasks.append(
|
gen_tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
model.container.generate(
|
model.container.generate(
|
||||||
prompt,
|
|
||||||
request.state.id,
|
request.state.id,
|
||||||
embeddings=embeddings,
|
prompt,
|
||||||
**data.model_dump(exclude={"prompt"}),
|
data,
|
||||||
|
mm_embeddings=embeddings,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -465,7 +465,6 @@ async def generate_tool_calls(
|
|||||||
# FIXME: May not be necessary depending on how the codebase evolves
|
# FIXME: May not be necessary depending on how the codebase evolves
|
||||||
tool_data = data.model_copy(deep=True)
|
tool_data = data.model_copy(deep=True)
|
||||||
tool_data.json_schema = tool_data.tool_call_schema
|
tool_data.json_schema = tool_data.tool_call_schema
|
||||||
gen_params = tool_data.model_dump()
|
|
||||||
|
|
||||||
for idx, gen in enumerate(generations):
|
for idx, gen in enumerate(generations):
|
||||||
if gen["stop_str"] in tool_data.tool_call_start:
|
if gen["stop_str"] in tool_data.tool_call_start:
|
||||||
@@ -488,10 +487,10 @@ async def generate_tool_calls(
|
|||||||
gen_tasks.append(
|
gen_tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
model.container.generate(
|
model.container.generate(
|
||||||
pre_tool_prompt,
|
|
||||||
request.state.id,
|
request.state.id,
|
||||||
|
pre_tool_prompt,
|
||||||
|
tool_data,
|
||||||
embeddings=mm_embeddings,
|
embeddings=mm_embeddings,
|
||||||
**gen_params,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import asyncio
|
|||||||
import pathlib
|
import pathlib
|
||||||
from asyncio import CancelledError
|
from asyncio import CancelledError
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from common import model
|
from common import model
|
||||||
from common.auth import get_key_permission
|
from common.auth import get_key_permission
|
||||||
|
from common.multimodal import MultimodalEmbeddingWrapper
|
||||||
from common.networking import (
|
from common.networking import (
|
||||||
get_generator_error,
|
get_generator_error,
|
||||||
handle_request_disconnect,
|
handle_request_disconnect,
|
||||||
@@ -86,16 +86,21 @@ def _create_response(
|
|||||||
async def _stream_collector(
|
async def _stream_collector(
|
||||||
task_idx: int,
|
task_idx: int,
|
||||||
gen_queue: asyncio.Queue,
|
gen_queue: asyncio.Queue,
|
||||||
prompt: str,
|
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
prompt: str,
|
||||||
|
params: CompletionRequest,
|
||||||
abort_event: asyncio.Event,
|
abort_event: asyncio.Event,
|
||||||
**kwargs,
|
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||||
):
|
):
|
||||||
"""Collects a stream and places results in a common queue"""
|
"""Collects a stream and places results in a common queue"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_generation = model.container.generate_gen(
|
new_generation = model.container.generate_gen(
|
||||||
prompt, request_id, abort_event, **kwargs
|
request_id,
|
||||||
|
prompt,
|
||||||
|
params,
|
||||||
|
abort_event,
|
||||||
|
mm_embeddings,
|
||||||
)
|
)
|
||||||
async for generation in new_generation:
|
async for generation in new_generation:
|
||||||
generation["index"] = task_idx
|
generation["index"] = task_idx
|
||||||
@@ -195,10 +200,10 @@ async def stream_generate_completion(
|
|||||||
_stream_collector(
|
_stream_collector(
|
||||||
n,
|
n,
|
||||||
gen_queue,
|
gen_queue,
|
||||||
data.prompt,
|
|
||||||
request.state.id,
|
request.state.id,
|
||||||
|
data.prompt,
|
||||||
|
task_gen_params,
|
||||||
abort_event,
|
abort_event,
|
||||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -256,9 +261,9 @@ async def generate_completion(
|
|||||||
gen_tasks.append(
|
gen_tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
model.container.generate(
|
model.container.generate(
|
||||||
data.prompt,
|
|
||||||
request.state.id,
|
request.state.id,
|
||||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
data.prompt,
|
||||||
|
task_gen_params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user