mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-30 19:21:34 +00:00
Merge branch 'main' into draft-split
This commit is contained in:
@@ -1,110 +1,16 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
import typing
|
||||||
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
|
||||||
from lmformatenforcer import (
|
|
||||||
JsonSchemaParser,
|
|
||||||
RegexParser,
|
|
||||||
TokenEnforcer,
|
|
||||||
CharacterLevelParser,
|
|
||||||
)
|
|
||||||
from lmformatenforcer.integrations.exllamav2 import (
|
|
||||||
build_token_enforcer_tokenizer_data,
|
|
||||||
)
|
|
||||||
from loguru import logger
|
|
||||||
from typing import List
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
class OutlinesTokenizerWrapper:
|
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||||
"""Wrapper for Outlines tokenizer"""
|
from exllamav2.generator.filters import ExLlamaV2Filter
|
||||||
|
from formatron.extractor import NonterminalExtractor
|
||||||
def __init__(self, tokenizer):
|
from formatron.formatter import FormatterBuilder
|
||||||
self.tokenizer = tokenizer
|
from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
|
||||||
id_to_piece = self.tokenizer.get_id_to_piece_list()
|
from formatron.schemas import json_schema
|
||||||
self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
|
from loguru import logger
|
||||||
self.eos_token_id = self.tokenizer.eos_token_id
|
|
||||||
self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
|
|
||||||
self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())
|
|
||||||
|
|
||||||
def convert_token_to_string(self, token):
|
|
||||||
return token
|
|
||||||
|
|
||||||
def decode(self, tokens):
|
|
||||||
s = ""
|
|
||||||
id_to_piece = self.tokenizer.get_id_to_piece_list()
|
|
||||||
for t in tokens:
|
|
||||||
s += id_to_piece[t]
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
|
|
||||||
"""Filter class for context-free grammar via outlines"""
|
|
||||||
|
|
||||||
def __init__(self, model, tokenizer, grammar):
|
|
||||||
from outlines.fsm.fsm import CFGFSM
|
|
||||||
|
|
||||||
super().__init__(model, tokenizer)
|
|
||||||
|
|
||||||
self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
|
|
||||||
self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
|
|
||||||
self.state = self.fsm.first_state
|
|
||||||
|
|
||||||
def begin(self, prefix_str=""):
|
|
||||||
self.state = self.fsm.first_state
|
|
||||||
|
|
||||||
def feed(self, token):
|
|
||||||
self.state = self.fsm.next_state(self.state, token.item())
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
return self.fsm.allowed_token_ids(self.state), set()
|
|
||||||
|
|
||||||
def use_background_worker(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(10)
|
|
||||||
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
|
|
||||||
return build_token_enforcer_tokenizer_data(tokenizer)
|
|
||||||
|
|
||||||
|
|
||||||
class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
|
|
||||||
"""Filter class for LMFE"""
|
|
||||||
|
|
||||||
token_sequence: List[int]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: ExLlamaV2,
|
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
|
||||||
character_level_parser: CharacterLevelParser,
|
|
||||||
):
|
|
||||||
super().__init__(model, tokenizer)
|
|
||||||
tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
|
|
||||||
self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
|
|
||||||
self.token_sequence = []
|
|
||||||
|
|
||||||
def begin(self, prefix_str: str):
|
|
||||||
self.token_sequence = []
|
|
||||||
|
|
||||||
def feed(self, token):
|
|
||||||
self.token_sequence.append(int(token[0][0]))
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
|
|
||||||
if not hasattr(self, "allow_return_type_list"):
|
|
||||||
return set(allowed_tokens), set()
|
|
||||||
else:
|
|
||||||
return sorted(allowed_tokens), []
|
|
||||||
|
|
||||||
def use_background_worker(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def clear_grammar_func_cache():
|
|
||||||
"""Flush tokenizer_data cache to avoid holding references to
|
|
||||||
tokenizers after unloading a model"""
|
|
||||||
|
|
||||||
_get_lmfe_tokenizer_data.cache_clear()
|
|
||||||
|
|
||||||
|
|
||||||
class ExLlamaV2Grammar:
|
class ExLlamaV2Grammar:
|
||||||
@@ -117,7 +23,7 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
def add_json_schema_filter(
|
def add_json_schema_filter(
|
||||||
self,
|
self,
|
||||||
json_schema: dict,
|
schema: dict,
|
||||||
model: ExLlamaV2,
|
model: ExLlamaV2,
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
):
|
):
|
||||||
@@ -125,7 +31,16 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
# Create the parser
|
# Create the parser
|
||||||
try:
|
try:
|
||||||
schema_parser = JsonSchemaParser(json_schema)
|
# Add fields required by formatron if not present
|
||||||
|
if "$id" not in schema:
|
||||||
|
schema["$id"] = "https://example.com/example.json"
|
||||||
|
if "$schema" not in schema:
|
||||||
|
schema["$schema"] = "http://json-schema.org/draft-07/schema#"
|
||||||
|
|
||||||
|
# Validate schema and create formatter
|
||||||
|
schema = json_schema.create_schema(schema)
|
||||||
|
f = FormatterBuilder()
|
||||||
|
f.append_line(f"{f.json(schema)}")
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -135,14 +50,10 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Allow JSON objects or JSON arrays at the top level
|
lmfilter = _create_formatter_filter(model, tokenizer, f)
|
||||||
json_prefixes = ["[", "{"]
|
|
||||||
|
|
||||||
lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
|
|
||||||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
|
|
||||||
|
|
||||||
# Append the filters
|
# Append the filters
|
||||||
self.filters.extend([lmfilter, prefix_filter])
|
self.filters.append(lmfilter)
|
||||||
|
|
||||||
def add_regex_filter(
|
def add_regex_filter(
|
||||||
self,
|
self,
|
||||||
@@ -154,7 +65,9 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
# Create the parser
|
# Create the parser
|
||||||
try:
|
try:
|
||||||
pattern_parser = RegexParser(pattern)
|
# Validate regex and create formatter
|
||||||
|
f = FormatterBuilder()
|
||||||
|
f.append_line(f"{f.regex(pattern)}")
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -164,32 +77,82 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
|
lmfilter = _create_formatter_filter(model, tokenizer, f)
|
||||||
|
|
||||||
# Append the filters
|
# Append the filters
|
||||||
self.filters.append(lmfilter)
|
self.filters.append(lmfilter)
|
||||||
|
|
||||||
def add_ebnf_filter(
|
def add_kbnf_filter(
|
||||||
self,
|
self,
|
||||||
ebnf_string: str,
|
kbnf_string: str,
|
||||||
model: ExLlamaV2,
|
model: ExLlamaV2,
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
):
|
):
|
||||||
"""
|
"""Adds an ExllamaV2 filter based on KBNF grammar."""
|
||||||
Add an EBNF grammar filter.
|
|
||||||
Possibly replace outlines with an in-house solution in the future.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
# Create the parser
|
||||||
try:
|
try:
|
||||||
ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
|
# Validate KBNF and create formatter
|
||||||
except ImportError:
|
f = FormatterBuilder()
|
||||||
|
f.append_line(
|
||||||
|
f"""{f.extractor(lambda nonterminal:
|
||||||
|
CFGExtractor(nonterminal, kbnf_string))}"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Skipping EBNF parsing because Outlines is not installed.\n"
|
"Skipping because the KBNF string couldn't be parsed. "
|
||||||
"Please run the following command in your environment "
|
"Please read the above error for more information."
|
||||||
"to install extra packages:\n"
|
|
||||||
"pip install -U .[extras]"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.filters.append(ebnf_filter)
|
lmfilter = _create_formatter_filter(model, tokenizer, f)
|
||||||
|
|
||||||
|
# Append the filters
|
||||||
|
self.filters.append(lmfilter)
|
||||||
|
|
||||||
|
|
||||||
|
class CFGExtractor(NonterminalExtractor):
|
||||||
|
"""Extractor class for KBNF context-free grammar"""
|
||||||
|
|
||||||
|
def __init__(self, nonterminal: str, kbnf_string: str):
|
||||||
|
super().__init__(nonterminal)
|
||||||
|
self.kbnf_string = kbnf_string
|
||||||
|
|
||||||
|
# Return the entire input string as the extracted string
|
||||||
|
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
|
||||||
|
return "", input_str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kbnf_definition(self) -> str:
|
||||||
|
return self.kbnf_string.replace("start", self.nonterminal)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(1)
|
||||||
|
def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
|
||||||
|
"""Build and cache engine vocabulary on first grammar run"""
|
||||||
|
|
||||||
|
return create_engine_vocabulary(tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_formatter_filter(
|
||||||
|
model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
|
||||||
|
) -> ExLlamaV2Filter:
|
||||||
|
"""
|
||||||
|
Create a formatter filter for the ExLlamaV2 engine.
|
||||||
|
Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
|
||||||
|
with lru_cache enabled for engine vocabulary
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab = _create_cached_engine_vocabulary(tokenizer)
|
||||||
|
f = formatter_builder.build(
|
||||||
|
vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
|
||||||
|
)
|
||||||
|
return FormatterFilter(model, tokenizer, f)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_grammar_func_cache():
|
||||||
|
"""Flush tokenizer_data cache to avoid holding references to
|
||||||
|
tokenizers after unloading a model"""
|
||||||
|
|
||||||
|
_create_cached_engine_vocabulary.cache_clear()
|
||||||
|
|||||||
@@ -498,16 +498,18 @@ class ExllamaV2Container:
|
|||||||
"rope_scale": self.config.scale_pos_emb,
|
"rope_scale": self.config.scale_pos_emb,
|
||||||
"rope_alpha": self.config.scale_alpha_value,
|
"rope_alpha": self.config.scale_alpha_value,
|
||||||
"max_seq_len": self.config.max_seq_len,
|
"max_seq_len": self.config.max_seq_len,
|
||||||
|
"max_batch_size": self.max_batch_size,
|
||||||
"cache_size": self.cache_size,
|
"cache_size": self.cache_size,
|
||||||
"cache_mode": self.cache_mode,
|
"cache_mode": self.cache_mode,
|
||||||
"chunk_size": self.config.max_input_len,
|
"chunk_size": self.config.max_input_len,
|
||||||
"num_experts_per_token": self.config.num_experts_per_token,
|
"num_experts_per_token": self.config.num_experts_per_token,
|
||||||
"prompt_template": self.prompt_template.name
|
|
||||||
if self.prompt_template
|
|
||||||
else None,
|
|
||||||
"use_vision": self.use_vision,
|
"use_vision": self.use_vision,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.prompt_template:
|
||||||
|
model_params["prompt_template"] = self.prompt_template.name
|
||||||
|
model_params["prompt_template_content"] = self.prompt_template.raw_template
|
||||||
|
|
||||||
if self.draft_config:
|
if self.draft_config:
|
||||||
draft_model_params = {
|
draft_model_params = {
|
||||||
"name": self.draft_model_dir.name,
|
"name": self.draft_model_dir.name,
|
||||||
@@ -787,6 +789,10 @@ class ExllamaV2Container:
|
|||||||
max_batch_size=self.max_batch_size,
|
max_batch_size=self.max_batch_size,
|
||||||
paged=self.paged,
|
paged=self.paged,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update the state of the container var
|
||||||
|
if self.max_batch_size is None:
|
||||||
|
self.max_batch_size = self.generator.generator.max_batch_size
|
||||||
finally:
|
finally:
|
||||||
# This means the generator is being recreated
|
# This means the generator is being recreated
|
||||||
# The load lock is already released in the load function
|
# The load lock is already released in the load function
|
||||||
@@ -1222,7 +1228,7 @@ class ExllamaV2Container:
|
|||||||
# Add EBNF filter if it exists
|
# Add EBNF filter if it exists
|
||||||
grammar_string = unwrap(kwargs.get("grammar_string"))
|
grammar_string = unwrap(kwargs.get("grammar_string"))
|
||||||
if grammar_string:
|
if grammar_string:
|
||||||
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
|
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
|
||||||
|
|
||||||
# Set banned strings
|
# Set banned strings
|
||||||
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
|
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
|
||||||
@@ -1329,17 +1335,49 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
# The first index will always be the positive prompt
|
# The first index will always be the positive prompt
|
||||||
context_len = input_ids[0].size(dim=-1)
|
context_len = input_ids[0].size(dim=-1)
|
||||||
if context_len > self.config.max_seq_len:
|
|
||||||
raise ValueError(
|
# The second index will be the negative prompt if CFG is enabled
|
||||||
f"Context length {context_len} is greater than max_seq_len "
|
negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0
|
||||||
f"{self.config.max_seq_len}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Automatically set max_tokens to fill up the context
|
# Automatically set max_tokens to fill up the context
|
||||||
# This should be an OK default, but may be changed in the future
|
# This should be an OK default, but may be changed in the future
|
||||||
max_tokens = unwrap(
|
max_tokens = unwrap(
|
||||||
kwargs.get("max_tokens"), self.config.max_seq_len - context_len
|
kwargs.get("max_tokens"),
|
||||||
|
self.config.max_seq_len - max(context_len, negative_context_len),
|
||||||
)
|
)
|
||||||
|
if max_tokens < 1:
|
||||||
|
logger.warning("max_tokens must be a positive integer, setting to 1.")
|
||||||
|
max_tokens = 1
|
||||||
|
|
||||||
|
# Determine if the negative context or the context length is bigger
|
||||||
|
context_to_check = max(negative_context_len, context_len)
|
||||||
|
|
||||||
|
# Check highest possible total length of request
|
||||||
|
if context_to_check + max_tokens > self.config.max_seq_len:
|
||||||
|
preamble = (
|
||||||
|
"Negative prompt request"
|
||||||
|
if negative_context_len > context_len
|
||||||
|
else "Request"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"{preamble} length {context_to_check} + {max_tokens} is greater than "
|
||||||
|
f"max_seq_len {self.config.max_seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check total required pages for CFG request to avoid overallocation
|
||||||
|
if negative_prompt and (
|
||||||
|
sum(
|
||||||
|
256 * math.ceil((context + max_tokens) / 256)
|
||||||
|
for context in (context_len, negative_context_len)
|
||||||
|
)
|
||||||
|
> self.cache_size
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Total required page size for request "
|
||||||
|
f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
|
||||||
|
f"is greater than cache_size {self.cache_size}"
|
||||||
|
)
|
||||||
|
|
||||||
# Set min_tokens to generate while keeping EOS banned
|
# Set min_tokens to generate while keeping EOS banned
|
||||||
min_tokens = unwrap(kwargs.get("min_tokens"), 0)
|
min_tokens = unwrap(kwargs.get("min_tokens"), 0)
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ def _log_formatter(record: dict):
|
|||||||
"ERROR": "red",
|
"ERROR": "red",
|
||||||
"CRITICAL": "bold white on red",
|
"CRITICAL": "bold white on red",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
time = record.get("time")
|
||||||
|
colored_time = f"[grey37]{time:YYYY-DD-MM HH:mm:ss.SSS}[/grey37]"
|
||||||
|
|
||||||
level = record.get("level")
|
level = record.get("level")
|
||||||
level_color = color_map.get(level.name, "cyan")
|
level_color = color_map.get(level.name, "cyan")
|
||||||
colored_level = f"[{level_color}]{level.name}[/{level_color}]:"
|
colored_level = f"[{level_color}]{level.name}[/{level_color}]:"
|
||||||
@@ -69,9 +73,11 @@ def _log_formatter(record: dict):
|
|||||||
|
|
||||||
fmt = ""
|
fmt = ""
|
||||||
if len(lines) > 1:
|
if len(lines) > 1:
|
||||||
fmt = "\n".join([f"{colored_level}{separator}{line}" for line in lines])
|
fmt = "\n".join(
|
||||||
|
[f"{colored_time} {colored_level}{separator}{line}" for line in lines]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
fmt = f"{colored_level}{separator}{message}"
|
fmt = f"{colored_time} {colored_level}{separator}{message}"
|
||||||
|
|
||||||
return fmt
|
return fmt
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from typing import List
|
|
||||||
from backends.exllamav2.vision import get_image_embedding
|
from backends.exllamav2.vision import get_image_embedding
|
||||||
from common import model
|
from common import model
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from common.optional_dependencies import dependencies
|
from common.optional_dependencies import dependencies
|
||||||
|
|
||||||
@@ -9,12 +10,12 @@ if dependencies.exllamav2:
|
|||||||
from exllamav2 import ExLlamaV2VisionTower
|
from exllamav2 import ExLlamaV2VisionTower
|
||||||
|
|
||||||
|
|
||||||
class MultimodalEmbeddingWrapper:
|
class MultimodalEmbeddingWrapper(BaseModel):
|
||||||
"""Common multimodal embedding wrapper"""
|
"""Common multimodal embedding wrapper"""
|
||||||
|
|
||||||
type: str = None
|
type: str = None
|
||||||
content: List = []
|
content: list = Field(default_factory=list)
|
||||||
text_alias: List[str] = []
|
text_alias: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
async def add(self, url: str):
|
async def add(self, url: str):
|
||||||
# Determine the type of vision embedding to use
|
# Determine the type of vision embedding to use
|
||||||
|
|||||||
@@ -14,14 +14,13 @@ class DependenciesModel(BaseModel):
|
|||||||
torch: bool
|
torch: bool
|
||||||
exllamav2: bool
|
exllamav2: bool
|
||||||
flash_attn: bool
|
flash_attn: bool
|
||||||
outlines: bool
|
|
||||||
infinity_emb: bool
|
infinity_emb: bool
|
||||||
sentence_transformers: bool
|
sentence_transformers: bool
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def extras(self) -> bool:
|
def extras(self) -> bool:
|
||||||
return self.outlines and self.infinity_emb and self.sentence_transformers
|
return self.infinity_emb and self.sentence_transformers
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
|
|
||||||
max_tokens: Optional[int] = Field(
|
max_tokens: Optional[int] = Field(
|
||||||
default_factory=lambda: get_default_sampler_value("max_tokens"),
|
default_factory=lambda: get_default_sampler_value("max_tokens"),
|
||||||
validation_alias=AliasChoices("max_tokens", "max_length"),
|
validation_alias=AliasChoices(
|
||||||
|
"max_tokens", "max_completion_tokens", "max_length"
|
||||||
|
),
|
||||||
description="Aliases: max_length",
|
description="Aliases: max_length",
|
||||||
examples=[150],
|
examples=[150],
|
||||||
ge=0,
|
ge=0,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
from asyncio import CancelledError
|
from asyncio import CancelledError
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from sse_starlette import ServerSentEvent
|
from sse_starlette.event import ServerSentEvent
|
||||||
|
|
||||||
from common import model
|
from common import model
|
||||||
from common.networking import (
|
from common.networking import (
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class EmbeddingsRequest(BaseModel):
|
|||||||
|
|
||||||
class EmbeddingObject(BaseModel):
|
class EmbeddingObject(BaseModel):
|
||||||
object: str = Field("embedding", description="Type of the object.")
|
object: str = Field("embedding", description="Type of the object.")
|
||||||
embedding: List[float] = Field(
|
embedding: Union[List[float], str] = Field(
|
||||||
..., description="Embedding values as a list of floats."
|
..., description="Embedding values as a list of floats."
|
||||||
)
|
)
|
||||||
index: int = Field(
|
index: int = Field(
|
||||||
|
|||||||
@@ -454,16 +454,23 @@ async def generate_tool_calls(
|
|||||||
if gen["stop_str"] in tool_data.tool_call_start:
|
if gen["stop_str"] in tool_data.tool_call_start:
|
||||||
if "text" in gen:
|
if "text" in gen:
|
||||||
# non streaming, all generations will have the text they generated
|
# non streaming, all generations will have the text they generated
|
||||||
pre_tool_prompt = await apply_chat_template(data, gen["text"])
|
pre_tool_prompt, mm_embeddings = await apply_chat_template(
|
||||||
|
data, gen["text"]
|
||||||
|
)
|
||||||
elif current_generations is not None:
|
elif current_generations is not None:
|
||||||
# streaming, we wont have text in the generation,
|
# streaming, we wont have text in the generation,
|
||||||
# we'll have to use the current_generations
|
# we'll have to use the current_generations
|
||||||
pre_tool_prompt = await apply_chat_template(data, current_generations)
|
pre_tool_prompt, mm_embeddings = await apply_chat_template(
|
||||||
|
data, current_generations
|
||||||
|
)
|
||||||
|
|
||||||
gen_tasks.append(
|
gen_tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
model.container.generate(
|
model.container.generate(
|
||||||
pre_tool_prompt, request.state.id, **gen_params
|
pre_tool_prompt,
|
||||||
|
request.state.id,
|
||||||
|
embeddings=mm_embeddings,
|
||||||
|
**gen_params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from sys import maxsize
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from common.multimodal import MultimodalEmbeddingWrapper
|
from common.multimodal import MultimodalEmbeddingWrapper
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
from common import model, sampling
|
from common import model, sampling
|
||||||
@@ -22,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadRespons
|
|||||||
from endpoints.core.types.model import (
|
from endpoints.core.types.model import (
|
||||||
EmbeddingModelLoadRequest,
|
EmbeddingModelLoadRequest,
|
||||||
ModelCard,
|
ModelCard,
|
||||||
|
ModelDefaultGenerationSettings,
|
||||||
ModelList,
|
ModelList,
|
||||||
ModelLoadRequest,
|
ModelLoadRequest,
|
||||||
ModelLoadResponse,
|
ModelLoadResponse,
|
||||||
|
ModelPropsResponse,
|
||||||
)
|
)
|
||||||
from endpoints.core.types.health import HealthCheckResponse
|
from endpoints.core.types.health import HealthCheckResponse
|
||||||
from endpoints.core.types.sampler_overrides import (
|
from endpoints.core.types.sampler_overrides import (
|
||||||
@@ -65,6 +68,34 @@ async def healthcheck(response: Response) -> HealthCheckResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/.well-known/serviceinfo")
|
||||||
|
async def service_info():
|
||||||
|
return JSONResponse(
|
||||||
|
content={
|
||||||
|
"version": 0.1,
|
||||||
|
"software": {
|
||||||
|
"name": "TabbyAPI",
|
||||||
|
"repository": "https://github.com/theroyallab/tabbyAPI",
|
||||||
|
"homepage": "https://github.com/theroyallab/tabbyAPI",
|
||||||
|
},
|
||||||
|
"api": {
|
||||||
|
"openai": {
|
||||||
|
"name": "OpenAI API",
|
||||||
|
"relative_url": "/v1",
|
||||||
|
"documentation": "https://theroyallab.github.io/tabbyAPI",
|
||||||
|
"version": 1,
|
||||||
|
},
|
||||||
|
"koboldai": {
|
||||||
|
"name": "KoboldAI API",
|
||||||
|
"relative_url": "/api",
|
||||||
|
"documentation": "https://theroyallab.github.io/tabbyAPI",
|
||||||
|
"version": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Model list endpoint
|
# Model list endpoint
|
||||||
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||||
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||||
@@ -102,6 +133,30 @@ async def current_model() -> ModelCard:
|
|||||||
return get_current_model()
|
return get_current_model()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/props", dependencies=[Depends(check_api_key), Depends(check_model_container)]
|
||||||
|
)
|
||||||
|
async def model_props() -> ModelPropsResponse:
|
||||||
|
"""
|
||||||
|
Returns specific properties of a model for clients.
|
||||||
|
|
||||||
|
To get all properties, use /v1/model instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_model_card = get_current_model()
|
||||||
|
resp = ModelPropsResponse(
|
||||||
|
total_slots=current_model_card.parameters.max_batch_size,
|
||||||
|
default_generation_settings=ModelDefaultGenerationSettings(
|
||||||
|
n_ctx=current_model_card.parameters.max_seq_len,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_model_card.parameters.prompt_template_content:
|
||||||
|
resp.chat_template = current_model_card.parameters.prompt_template_content
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_draft_models(request: Request) -> ModelList:
|
async def list_draft_models(request: Request) -> ModelList:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel):
|
|||||||
max_seq_len: Optional[int] = None
|
max_seq_len: Optional[int] = None
|
||||||
rope_scale: Optional[float] = 1.0
|
rope_scale: Optional[float] = 1.0
|
||||||
rope_alpha: Optional[float] = 1.0
|
rope_alpha: Optional[float] = 1.0
|
||||||
|
max_batch_size: Optional[int] = 1
|
||||||
cache_size: Optional[int] = None
|
cache_size: Optional[int] = None
|
||||||
cache_mode: Optional[str] = "FP16"
|
cache_mode: Optional[str] = "FP16"
|
||||||
chunk_size: Optional[int] = 2048
|
chunk_size: Optional[int] = 2048
|
||||||
prompt_template: Optional[str] = None
|
prompt_template: Optional[str] = None
|
||||||
|
prompt_template_content: Optional[str] = None
|
||||||
num_experts_per_token: Optional[int] = None
|
num_experts_per_token: Optional[int] = None
|
||||||
use_vision: Optional[bool] = False
|
use_vision: Optional[bool] = False
|
||||||
|
|
||||||
@@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel):
|
|||||||
module: int
|
module: int
|
||||||
modules: int
|
modules: int
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDefaultGenerationSettings(BaseModel):
|
||||||
|
"""Contains default generation settings for model props."""
|
||||||
|
|
||||||
|
n_ctx: int
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPropsResponse(BaseModel):
|
||||||
|
"""Represents a model props response."""
|
||||||
|
|
||||||
|
total_slots: int = 1
|
||||||
|
chat_template: str = ""
|
||||||
|
default_generation_settings: ModelDefaultGenerationSettings
|
||||||
|
|||||||
@@ -16,30 +16,30 @@ version = "0.0.1"
|
|||||||
description = "An OAI compatible exllamav2 API that's both lightweight and fast"
|
description = "An OAI compatible exllamav2 API that's both lightweight and fast"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi-slim >= 0.110.0",
|
"fastapi-slim >= 0.115",
|
||||||
"pydantic >= 2.0.0",
|
"pydantic >= 2.0.0",
|
||||||
"ruamel.yaml",
|
"ruamel.yaml",
|
||||||
"rich",
|
"rich",
|
||||||
"uvicorn >= 0.28.1",
|
"uvicorn >= 0.28.1",
|
||||||
"jinja2 >= 3.0.0",
|
"jinja2 >= 3.0.0",
|
||||||
"loguru",
|
"loguru",
|
||||||
"sse-starlette",
|
"sse-starlette >= 2.2.0",
|
||||||
"packaging",
|
"packaging",
|
||||||
"tokenizers",
|
"tokenizers >= 0.21.0",
|
||||||
"lm-format-enforcer >= 0.9.6",
|
"formatron >= 0.4.11",
|
||||||
|
"kbnf >= 0.4.1",
|
||||||
"aiofiles",
|
"aiofiles",
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"async_lru",
|
"async_lru",
|
||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
"psutil",
|
"psutil",
|
||||||
"httptools>=0.5.0",
|
"httptools >= 0.5.0",
|
||||||
"pillow",
|
"pillow",
|
||||||
|
|
||||||
# Improved asyncio loops
|
# Improved asyncio loops
|
||||||
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||||
"winloop ; platform_system == 'Windows'",
|
"winloop ; platform_system == 'Windows'",
|
||||||
|
|
||||||
# TEMP: Remove once 2.x is fixed in upstream
|
|
||||||
"numpy < 2.0.0",
|
"numpy < 2.0.0",
|
||||||
|
|
||||||
# For python 3.12
|
# For python 3.12
|
||||||
@@ -53,7 +53,6 @@ dependencies = [
|
|||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
extras = [
|
extras = [
|
||||||
# Heavy dependencies that aren't for everyday use
|
# Heavy dependencies that aren't for everyday use
|
||||||
"outlines",
|
|
||||||
"infinity-emb",
|
"infinity-emb",
|
||||||
"sentence-transformers",
|
"sentence-transformers",
|
||||||
]
|
]
|
||||||
@@ -62,68 +61,46 @@ dev = [
|
|||||||
]
|
]
|
||||||
cu121 = [
|
cu121 = [
|
||||||
# Torch (Extra index URLs not support in pyproject.toml)
|
# Torch (Extra index URLs not support in pyproject.toml)
|
||||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||||
"torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
"torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||||
|
|
||||||
# Exl2
|
# Exl2
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||||
|
|
||||||
# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
|
# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
|
||||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
||||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
||||||
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
||||||
|
|
||||||
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
|
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
|
||||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.5cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
||||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
||||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
||||||
]
|
|
||||||
cu118 = [
|
|
||||||
# Torch
|
|
||||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
|
||||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
|
||||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
|
||||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
|
||||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
|
||||||
"torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
|
||||||
|
|
||||||
# Exl2
|
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
|
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
|
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
|
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
|
||||||
|
|
||||||
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
|
|
||||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
|
|
||||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
|
|
||||||
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
|
|
||||||
]
|
]
|
||||||
amd = [
|
amd = [
|
||||||
# Torch triton for ROCm
|
# Torch triton for ROCm
|
||||||
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
|
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.1.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
|
||||||
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
|
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.1.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
|
||||||
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
"pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.1.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
||||||
|
|
||||||
# Torch
|
# Torch
|
||||||
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
|
"torch @ https://download.pytorch.org/whl/rocm6.2/torch-2.5.1%2Brocm6.2-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
|
||||||
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
|
"torch @ https://download.pytorch.org/whl/rocm6.2/torch-2.5.1%2Brocm6.2-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
|
||||||
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
"torch @ https://download.pytorch.org/whl/rocm6.2/torch-2.5.1%2Brocm6.2-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
||||||
|
|
||||||
# Exl2
|
# Exl2
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+rocm6.2.torch2.5.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+rocm6.2.torch2.5.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
|
||||||
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+rocm6.2.torch2.5.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
|
||||||
]
|
]
|
||||||
|
|
||||||
# MARK: Ruff options
|
# MARK: Ruff options
|
||||||
|
|||||||
2
start.py
2
start.py
@@ -47,7 +47,7 @@ def get_install_features(lib_name: str = None):
|
|||||||
# Ask the user for the GPU lib
|
# Ask the user for the GPU lib
|
||||||
gpu_lib_choices = {
|
gpu_lib_choices = {
|
||||||
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
|
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
|
||||||
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
|
"B": {"pretty": "NVIDIA Cuda 11.8 (Unsupported)", "internal": "cu118"},
|
||||||
"C": {"pretty": "AMD", "internal": "amd"},
|
"C": {"pretty": "AMD", "internal": "amd"},
|
||||||
}
|
}
|
||||||
user_input = get_user_choice(
|
user_input = get_user_choice(
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
{%- set start_header = "<|start_header_id|>" -%}
|
{%- set start_header = "<|start_header_id|>" -%}
|
||||||
{%- set end_header = "<|end_header_id|>\n" -%}
|
{%- set end_header = "<|end_header_id|>\n" -%}
|
||||||
|
|
||||||
{%- set example_tool_call -%}[
|
{%- set example_tool_call = '[
|
||||||
{
|
{
|
||||||
"id": "tool_id_1342",
|
"id": "tool_id_1342",
|
||||||
"function": {
|
"function": {
|
||||||
@@ -23,29 +23,26 @@
|
|||||||
},
|
},
|
||||||
"type": "function"
|
"type": "function"
|
||||||
}
|
}
|
||||||
]
|
]' -%}
|
||||||
{%- endset -%}
|
|
||||||
|
|
||||||
{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
|
{%- set inital_system_prompt = 'You are an assistant that has access to the following set of tools, to call a tool:
|
||||||
1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
|
1. Prefix calls with ' + tool_start + ' and end calls with ' + tool_end + '
|
||||||
2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
|
2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
|
||||||
3. Generate all calls using the following json tool call format. Here is a multi tool call example:
|
3. Generate all calls using the following json tool call format. Here is a multi tool call example:
|
||||||
|
|
||||||
{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
|
' + tool_start + example_tool_call + tool_end + '
|
||||||
|
|
||||||
Here are the tools available for you to call:
|
Here are the tools available for you to call:
|
||||||
{{ tools_json }}
|
' + tools_json -%}
|
||||||
{%- endset -%}
|
|
||||||
|
|
||||||
{%- set tool_reminder -%}Available Tools:
|
{%- set tool_reminder = 'Available Tools:
|
||||||
{{ tools_json }}
|
' + tools_json + '
|
||||||
|
|
||||||
Tool Call Format Example:
|
Tool Call Format Example:
|
||||||
{{ tool_start }}{{ example_tool_call }}
|
' + tool_start + example_tool_call + '
|
||||||
|
|
||||||
Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
|
Prefix & Suffix: Begin tool calls with ' + tool_start + ' and end with ' + tool_end + '.
|
||||||
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
|
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).' -%}
|
||||||
{%- endset -%}
|
|
||||||
|
|
||||||
{# Template #}
|
{# Template #}
|
||||||
|
|
||||||
@@ -54,15 +51,15 @@ Argument Types: Use correct data types for arguments (e.g., strings in quotes, n
|
|||||||
{%- if role not in message_roles -%}
|
{%- if role not in message_roles -%}
|
||||||
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
|
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
|
|
||||||
{%- set content = message['content'] | default('', true) | trim -%}
|
{%- set content = message['content'] if message['content'] is defined else '' | trim -%}
|
||||||
{%- if loop.first -%}
|
{%- if loop.first -%}
|
||||||
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
|
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
|
||||||
{{ inital_system_prompt }}
|
{{ inital_system_prompt }}
|
||||||
|
|
||||||
{{ content }}{{ eos_token }}
|
{{ content }}{{ eos_token }}
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
|
|
||||||
{%- if not loop.first -%}
|
{%- if not loop.first -%}
|
||||||
{{ start_header }}{{ role }}{{ end_header }}
|
{{ start_header }}{{ role }}{{ end_header }}
|
||||||
{{ content }}
|
{{ content }}
|
||||||
@@ -81,4 +78,4 @@ Argument Types: Use correct data types for arguments (e.g., strings in quotes, n
|
|||||||
{{ tool_precursor }}{{ tool_start }}
|
{{ tool_precursor }}{{ tool_start }}
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{{ start_header }}assistant{{ end_header }}
|
{{ start_header }}assistant{{ end_header }}
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
|
|||||||
Reference in New Issue
Block a user