Merge branch 'main' into draft-split

This commit is contained in:
kingbri
2025-02-08 15:10:44 -05:00
14 changed files with 287 additions and 226 deletions

View File

@@ -1,110 +1,16 @@
 import traceback
-from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import (
-    JsonSchemaParser,
-    RegexParser,
-    TokenEnforcer,
-    CharacterLevelParser,
-)
-from lmformatenforcer.integrations.exllamav2 import (
-    build_token_enforcer_tokenizer_data,
-)
-from loguru import logger
-from typing import List
+import typing
 from functools import lru_cache
+from typing import List
+
+import torch
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
+from exllamav2.generator.filters import ExLlamaV2Filter
+from formatron.extractor import NonterminalExtractor
+from formatron.formatter import FormatterBuilder
+from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
+from formatron.schemas import json_schema
+from loguru import logger
-
-
-class OutlinesTokenizerWrapper:
-    """Wrapper for Outlines tokenizer"""
-
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
-        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())
-
-    def convert_token_to_string(self, token):
-        return token
-
-    def decode(self, tokens):
-        s = ""
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        for t in tokens:
-            s += id_to_piece[t]
-        return s
-
-
-class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
-    """Filter class for context-free grammar via outlines"""
-
-    def __init__(self, model, tokenizer, grammar):
-        from outlines.fsm.fsm import CFGFSM
-
-        super().__init__(model, tokenizer)
-
-        self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
-        self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
-        self.state = self.fsm.first_state
-
-    def begin(self, prefix_str=""):
-        self.state = self.fsm.first_state
-
-    def feed(self, token):
-        self.state = self.fsm.next_state(self.state, token.item())
-
-    def next(self):
-        return self.fsm.allowed_token_ids(self.state), set()
-
-    def use_background_worker(self):
-        return True
-
-
-@lru_cache(10)
-def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
-    return build_token_enforcer_tokenizer_data(tokenizer)
-
-
-class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
-    """Filter class for LMFE"""
-
-    token_sequence: List[int]
-
-    def __init__(
-        self,
-        model: ExLlamaV2,
-        tokenizer: ExLlamaV2Tokenizer,
-        character_level_parser: CharacterLevelParser,
-    ):
-        super().__init__(model, tokenizer)
-
-        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
-        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
-        self.token_sequence = []
-
-    def begin(self, prefix_str: str):
-        self.token_sequence = []
-
-    def feed(self, token):
-        self.token_sequence.append(int(token[0][0]))
-
-    def next(self):
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
-        if not hasattr(self, "allow_return_type_list"):
-            return set(allowed_tokens), set()
-        else:
-            return sorted(allowed_tokens), []
-
-    def use_background_worker(self):
-        return True
-
-
-def clear_grammar_func_cache():
-    """Flush tokenizer_data cache to avoid holding references to
-    tokenizers after unloading a model"""
-
-    _get_lmfe_tokenizer_data.cache_clear()


 class ExLlamaV2Grammar:
@@ -117,7 +23,7 @@ class ExLlamaV2Grammar:

     def add_json_schema_filter(
         self,
-        json_schema: dict,
+        schema: dict,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -125,7 +31,16 @@ class ExLlamaV2Grammar:

         # Create the parser
         try:
-            schema_parser = JsonSchemaParser(json_schema)
+            # Add fields required by formatron if not present
+            if "$id" not in schema:
+                schema["$id"] = "https://example.com/example.json"
+            if "$schema" not in schema:
+                schema["$schema"] = "http://json-schema.org/draft-07/schema#"
+
+            # Validate schema and create formatter
+            schema = json_schema.create_schema(schema)
+            f = FormatterBuilder()
+            f.append_line(f"{f.json(schema)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -135,14 +50,10 @@ class ExLlamaV2Grammar:

             return

-        # Allow JSON objects or JSON arrays at the top level
-        json_prefixes = ["[", "{"]
-
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
-        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
-        self.filters.extend([lmfilter, prefix_filter])
+        self.filters.append(lmfilter)

     def add_regex_filter(
         self,
@@ -154,7 +65,9 @@ class ExLlamaV2Grammar:

         # Create the parser
         try:
-            pattern_parser = RegexParser(pattern)
+            # Validate regex and create formatter
+            f = FormatterBuilder()
+            f.append_line(f"{f.regex(pattern)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -164,32 +77,82 @@ class ExLlamaV2Grammar:

             return

-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
         self.filters.append(lmfilter)

-    def add_ebnf_filter(
+    def add_kbnf_filter(
         self,
-        ebnf_string: str,
+        kbnf_string: str,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
-        """
-        Add an EBNF grammar filter.
-        Possibly replace outlines with an in-house solution in the future.
-        """
+        """Adds an ExllamaV2 filter based on KBNF grammar."""

-        # Create the parser
         try:
-            ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
-        except ImportError:
+            # Validate KBNF and create formatter
+            f = FormatterBuilder()
+            f.append_line(
+                f"""{f.extractor(lambda nonterminal:
+                CFGExtractor(nonterminal, kbnf_string))}"""
+            )
+        except Exception:
             logger.error(
-                "Skipping EBNF parsing because Outlines is not installed.\n"
-                "Please run the following command in your environment "
-                "to install extra packages:\n"
-                "pip install -U .[extras]"
+                "Skipping because the KBNF string couldn't be parsed. "
+                "Please read the above error for more information."
             )

             return

-        self.filters.append(ebnf_filter)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
+
+        # Append the filters
+        self.filters.append(lmfilter)
+
+
+class CFGExtractor(NonterminalExtractor):
+    """Extractor class for KBNF context-free grammar"""
+
+    def __init__(self, nonterminal: str, kbnf_string: str):
+        super().__init__(nonterminal)
+        self.kbnf_string = kbnf_string
+
+    # Return the entire input string as the extracted string
+    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+        return "", input_str
+
+    @property
+    def kbnf_definition(self) -> str:
+        return self.kbnf_string.replace("start", self.nonterminal)
+
+
+@lru_cache(1)
+def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
+    """Build and cache engine vocabulary on first grammar run"""
+
+    return create_engine_vocabulary(tokenizer)
+
+
+def _create_formatter_filter(
+    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
+) -> ExLlamaV2Filter:
+    """
+    Create a formatter filter for the ExLlamaV2 engine.
+
+    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
+    with lru_cache enabled for engine vocabulary
+    """
+
+    vocab = _create_cached_engine_vocabulary(tokenizer)
+    f = formatter_builder.build(
+        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
+    )
+    return FormatterFilter(model, tokenizer, f)
+
+
+def clear_grammar_func_cache():
+    """Flush tokenizer_data cache to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _create_cached_engine_vocabulary.cache_clear()
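
Note: the formatron rewrite above collapses the two previous paths (LMFE token enforcer plus a JSON prefix filter, and the outlines CFG filter) into one FormatterBuilder pipeline that always yields a single FormatterFilter. A minimal usage sketch of the new JSON-schema route, assuming a loaded ExLlamaV2 model and tokenizer; the schema dict is illustrative:

# Sketch only: `model` and `tokenizer` are assumed to be a loaded
# ExLlamaV2 / ExLlamaV2Tokenizer pair owned by the container.
grammar_handler = ExLlamaV2Grammar()

schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# Fills in "$id"/"$schema" defaults, validates the schema through
# formatron.schemas.json_schema, and appends one FormatterFilter.
grammar_handler.add_json_schema_filter(schema, model, tokenizer)

# grammar_handler.filters is then handed to the generator job, and
# clear_grammar_func_cache() should run on model unload so the cached
# engine vocabulary doesn't pin the old tokenizer.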

View File

@@ -498,16 +498,18 @@ class ExllamaV2Container:
             "rope_scale": self.config.scale_pos_emb,
             "rope_alpha": self.config.scale_alpha_value,
             "max_seq_len": self.config.max_seq_len,
+            "max_batch_size": self.max_batch_size,
             "cache_size": self.cache_size,
             "cache_mode": self.cache_mode,
             "chunk_size": self.config.max_input_len,
             "num_experts_per_token": self.config.num_experts_per_token,
-            "prompt_template": self.prompt_template.name
-            if self.prompt_template
-            else None,
             "use_vision": self.use_vision,
         }

+        if self.prompt_template:
+            model_params["prompt_template"] = self.prompt_template.name
+            model_params["prompt_template_content"] = self.prompt_template.raw_template
+
         if self.draft_config:
             draft_model_params = {
                 "name": self.draft_model_dir.name,
@@ -787,6 +789,10 @@ class ExllamaV2Container:
                 max_batch_size=self.max_batch_size,
                 paged=self.paged,
             )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
         finally:
             # This means the generator is being recreated
             # The load lock is already released in the load function
@@ -1222,7 +1228,7 @@ class ExllamaV2Container:
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)

         # Set banned strings
         banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
@@ -1329,17 +1335,49 @@ class ExllamaV2Container:
         # The first index will always be the positive prompt
         context_len = input_ids[0].size(dim=-1)
-        if context_len > self.config.max_seq_len:
-            raise ValueError(
-                f"Context length {context_len} is greater than max_seq_len "
-                f"{self.config.max_seq_len}"
-            )
+
+        # The second index will be the negative prompt if CFG is enabled
+        negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0

         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"), self.config.max_seq_len - context_len
+            kwargs.get("max_tokens"),
+            self.config.max_seq_len - max(context_len, negative_context_len),
         )
+        if max_tokens < 1:
+            logger.warning("max_tokens must be a positive integer, setting to 1.")
+            max_tokens = 1
+
+        # Determine if the negative context or the context length is bigger
+        context_to_check = max(negative_context_len, context_len)
+
+        # Check highest possible total length of request
+        if context_to_check + max_tokens > self.config.max_seq_len:
+            preamble = (
+                "Negative prompt request"
+                if negative_context_len > context_len
+                else "Request"
+            )
+
+            raise ValueError(
+                f"{preamble} length {context_to_check} + {max_tokens} is greater than "
+                f"max_seq_len {self.config.max_seq_len}"
+            )
+
+        # Check total required pages for CFG request to avoid overallocation
+        if negative_prompt and (
+            sum(
+                256 * math.ceil((context + max_tokens) / 256)
+                for context in (context_len, negative_context_len)
+            )
+            > self.cache_size
+        ):
+            raise ValueError(
+                f"Total required page size for request "
+                f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
+                f"is greater than cache_size {self.cache_size}"
+            )

         # Set min_tokens to generate while keeping EOS banned
         min_tokens = unwrap(kwargs.get("min_tokens"), 0)
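
Note: the new CFG check above sizes the paged cache per sequence, rounding each prompt plus max_tokens up to ExLlamaV2's 256-token page boundary before comparing against cache_size. A worked sketch of that arithmetic (standalone; the numbers are illustrative):

import math

def cfg_cache_tokens(context_len: int, negative_context_len: int, max_tokens: int) -> int:
    """Paged-cache tokens a CFG request needs: two sequences, each rounded
    up to a 256-token page, matching the check in the diff above."""
    return sum(
        256 * math.ceil((context + max_tokens) / 256)
        for context in (context_len, negative_context_len)
    )

# A 1000-token prompt and a 900-token negative prompt with 500 new tokens:
# ceil(1500 / 256) = 6 pages and ceil(1400 / 256) = 6 pages -> 3072 tokens,
# which fits a cache_size of 4096 but not 2048.
assert cfg_cache_tokens(1000, 900, 500) == 3072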

View File

@@ -52,6 +52,10 @@ def _log_formatter(record: dict):
         "ERROR": "red",
         "CRITICAL": "bold white on red",
     }

+    time = record.get("time")
+    colored_time = f"[grey37]{time:YYYY-DD-MM HH:mm:ss.SSS}[/grey37]"
+
     level = record.get("level")
     level_color = color_map.get(level.name, "cyan")
     colored_level = f"[{level_color}]{level.name}[/{level_color}]:"
@@ -69,9 +73,11 @@ def _log_formatter(record: dict):
     fmt = ""

     if len(lines) > 1:
-        fmt = "\n".join([f"{colored_level}{separator}{line}" for line in lines])
+        fmt = "\n".join(
+            [f"{colored_time} {colored_level}{separator}{line}" for line in lines]
+        )
     else:
-        fmt = f"{colored_level}{separator}{message}"
+        fmt = f"{colored_time} {colored_level}{separator}{message}"

     return fmt
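
Note: record["time"] is loguru's datetime, whose __format__ accepts loguru time tokens, so the f-string renders the timestamp inline; both branches now prepend it, and the multi-line branch repeats it on every wrapped line. A minimal sketch of the assembled markup (values made up; note the format string above spells the token order as YYYY-DD-MM):

# Illustrative shape of the rendered rich markup:
colored_time = "[grey37]2025-08-02 15:10:44.123[/grey37]"
colored_level = "[green]INFO[/green]:"
separator, message = " ", "Model loaded"
print(f"{colored_time} {colored_level}{separator}{message}")
# [grey37]2025-08-02 15:10:44.123[/grey37] [green]INFO[/green]: Model loaded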

View File

@@ -1,7 +1,8 @@
-from typing import List
 from backends.exllamav2.vision import get_image_embedding
 from common import model
 from loguru import logger
+from pydantic import BaseModel, Field
+from typing import List

 from common.optional_dependencies import dependencies
@@ -9,12 +10,12 @@ if dependencies.exllamav2:
     from exllamav2 import ExLlamaV2VisionTower


-class MultimodalEmbeddingWrapper:
+class MultimodalEmbeddingWrapper(BaseModel):
     """Common multimodal embedding wrapper"""

     type: str = None
-    content: List = []
-    text_alias: List[str] = []
+    content: list = Field(default_factory=list)
+    text_alias: List[str] = Field(default_factory=list)

     async def add(self, url: str):
         # Determine the type of vision embedding to use
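
Note: the move to BaseModel with Field(default_factory=list) sidesteps the shared-mutable-default trap the old plain class was exposed to: content: List = [] on a plain class is one class attribute shared by every instance. A self-contained sketch of the difference (throwaway class names):

from typing import List
from pydantic import BaseModel, Field

class PlainWrapper:
    content: List = []  # one list object shared by all instances

class ModelWrapper(BaseModel):
    content: list = Field(default_factory=list)  # fresh list per instance

a, b = PlainWrapper(), PlainWrapper()
a.content.append("embedding")
assert b.content == ["embedding"]  # leaked across instances

c, d = ModelWrapper(), ModelWrapper()
c.content.append("embedding")
assert d.content == []  # isolated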

View File

@@ -14,14 +14,13 @@ class DependenciesModel(BaseModel):
     torch: bool
     exllamav2: bool
     flash_attn: bool
-    outlines: bool
     infinity_emb: bool
     sentence_transformers: bool

     @computed_field
     @property
     def extras(self) -> bool:
-        return self.outlines and self.infinity_emb and self.sentence_transformers
+        return self.infinity_emb and self.sentence_transformers

     @computed_field
     @property

View File

@@ -25,7 +25,9 @@ class BaseSamplerRequest(BaseModel):
     max_tokens: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("max_tokens"),
-        validation_alias=AliasChoices("max_tokens", "max_length"),
+        validation_alias=AliasChoices(
+            "max_tokens", "max_completion_tokens", "max_length"
+        ),
         description="Aliases: max_length",
         examples=[150],
         ge=0,
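
Note: adding max_completion_tokens to the alias list lets clients that follow newer OpenAI conventions (where max_completion_tokens supersedes max_tokens) populate the same field. A small pydantic v2 sketch with a throwaway model:

from typing import Optional
from pydantic import AliasChoices, BaseModel, Field

class Req(BaseModel):
    max_tokens: Optional[int] = Field(
        default=None,
        validation_alias=AliasChoices(
            "max_tokens", "max_completion_tokens", "max_length"
        ),
    )

# Any of the three request keys lands in the same field:
assert Req.model_validate({"max_tokens": 150}).max_tokens == 150
assert Req.model_validate({"max_completion_tokens": 150}).max_tokens == 150
assert Req.model_validate({"max_length": 150}).max_tokens == 150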

View File

@@ -2,7 +2,7 @@ import asyncio
 from asyncio import CancelledError

 from fastapi import HTTPException, Request
 from loguru import logger
-from sse_starlette import ServerSentEvent
+from sse_starlette.event import ServerSentEvent

 from common import model
 from common.networking import (

View File

@@ -27,7 +27,7 @@ class EmbeddingsRequest(BaseModel):

 class EmbeddingObject(BaseModel):
     object: str = Field("embedding", description="Type of the object.")
-    embedding: List[float] = Field(
+    embedding: Union[List[float], str] = Field(
         ..., description="Embedding values as a list of floats."
     )
     index: int = Field(
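
Note: widening embedding to Union[List[float], str] presumably accommodates base64-encoded payloads alongside plain float lists, in the style of OpenAI's encoding_format="base64". The wire format below is an assumption, not something this diff specifies:

import base64
import struct

floats = [0.1, -0.2, 0.3]
# Pack little-endian float32 values, then base64-encode the raw bytes.
payload = base64.b64encode(struct.pack(f"<{len(floats)}f", *floats)).decode()
# A client reverses it the same way.
recovered = struct.unpack(f"<{len(floats)}f", base64.b64decode(payload))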

View File

@@ -454,16 +454,23 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = await apply_chat_template(data, gen["text"])
+                pre_tool_prompt, mm_embeddings = await apply_chat_template(
+                    data, gen["text"]
+                )
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = await apply_chat_template(data, current_generations)
+                pre_tool_prompt, mm_embeddings = await apply_chat_template(
+                    data, current_generations
+                )

             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        pre_tool_prompt, request.state.id, **gen_params
+                        pre_tool_prompt,
+                        request.state.id,
+                        embeddings=mm_embeddings,
+                        **gen_params,
                     )
                 )
             )

View File

@@ -4,6 +4,7 @@ from sys import maxsize
 from typing import Optional

 from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
+from fastapi.responses import JSONResponse
 from sse_starlette import EventSourceResponse

 from common import model, sampling
@@ -22,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadRespons
 from endpoints.core.types.model import (
     EmbeddingModelLoadRequest,
     ModelCard,
+    ModelDefaultGenerationSettings,
     ModelList,
     ModelLoadRequest,
     ModelLoadResponse,
+    ModelPropsResponse,
 )
 from endpoints.core.types.health import HealthCheckResponse
 from endpoints.core.types.sampler_overrides import (
@@ -65,6 +68,34 @@ async def healthcheck(response: Response) -> HealthCheckResponse:
     )


+@router.get("/.well-known/serviceinfo")
+async def service_info():
+    return JSONResponse(
+        content={
+            "version": 0.1,
+            "software": {
+                "name": "TabbyAPI",
+                "repository": "https://github.com/theroyallab/tabbyAPI",
+                "homepage": "https://github.com/theroyallab/tabbyAPI",
+            },
+            "api": {
+                "openai": {
+                    "name": "OpenAI API",
+                    "relative_url": "/v1",
+                    "documentation": "https://theroyallab.github.io/tabbyAPI",
+                    "version": 1,
+                },
+                "koboldai": {
+                    "name": "KoboldAI API",
+                    "relative_url": "/api",
+                    "documentation": "https://theroyallab.github.io/tabbyAPI",
+                    "version": 1,
+                },
+            },
+        }
+    )
+
+
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -102,6 +133,30 @@ async def current_model() -> ModelCard:
     return get_current_model()


+@router.get(
+    "/props", dependencies=[Depends(check_api_key), Depends(check_model_container)]
+)
+async def model_props() -> ModelPropsResponse:
+    """
+    Returns specific properties of a model for clients.
+
+    To get all properties, use /v1/model instead.
+    """
+
+    current_model_card = get_current_model()
+    resp = ModelPropsResponse(
+        total_slots=current_model_card.parameters.max_batch_size,
+        default_generation_settings=ModelDefaultGenerationSettings(
+            n_ctx=current_model_card.parameters.max_seq_len,
+        ),
+    )
+
+    if current_model_card.parameters.prompt_template_content:
+        resp.chat_template = current_model_card.parameters.prompt_template_content
+
+    return resp
+
+
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
 async def list_draft_models(request: Request) -> ModelList:
     """

View File

@@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = None
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
+    max_batch_size: Optional[int] = 1
     cache_size: Optional[int] = None
     cache_mode: Optional[str] = "FP16"
     chunk_size: Optional[int] = 2048
     prompt_template: Optional[str] = None
+    prompt_template_content: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_vision: Optional[bool] = False
@@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel):
     module: int
     modules: int
     status: str
+
+
+class ModelDefaultGenerationSettings(BaseModel):
+    """Contains default generation settings for model props."""
+
+    n_ctx: int
+
+
+class ModelPropsResponse(BaseModel):
+    """Represents a model props response."""
+
+    total_slots: int = 1
+    chat_template: str = ""
+    default_generation_settings: ModelDefaultGenerationSettings

View File

@@ -16,30 +16,30 @@ version = "0.0.1"
 description = "An OAI compatible exllamav2 API that's both lightweight and fast"
 requires-python = ">=3.10"
 dependencies = [
-    "fastapi-slim >= 0.110.0",
+    "fastapi-slim >= 0.115",
     "pydantic >= 2.0.0",
     "ruamel.yaml",
     "rich",
     "uvicorn >= 0.28.1",
     "jinja2 >= 3.0.0",
     "loguru",
-    "sse-starlette",
+    "sse-starlette >= 2.2.0",
     "packaging",
-    "tokenizers",
-    "lm-format-enforcer >= 0.9.6",
+    "tokenizers >= 0.21.0",
+    "formatron >= 0.4.11",
+    "kbnf >= 0.4.1",
     "aiofiles",
     "aiohttp",
     "async_lru",
     "huggingface_hub",
     "psutil",
-    "httptools>=0.5.0",
+    "httptools >= 0.5.0",
     "pillow",
     # Improved asyncio loops
     "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",
     "winloop ; platform_system == 'Windows'",
-    # TEMP: Remove once 2.x is fixed in upstream
     "numpy < 2.0.0",
     # For python 3.12
@@ -53,7 +53,6 @@ dependencies = [

 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines",
     "infinity-emb",
     "sentence-transformers",
 ]
@@ -62,68 +61,46 @@ dev = [
 ]
 cu121 = [
     # Torch (Extra index URLs not support in pyproject.toml)
-    "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu121.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+cu121.torch2.5.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
     # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
-    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu124torch2.5.1cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
-]
-cu118 = [
-    # Torch
-    "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
-    # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+cu118.torch2.4.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
-    # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.5cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 ]
 amd = [
     # Torch triton for ROCm
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.0.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.1.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.1.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "pytorch_triton_rocm @ https://download.pytorch.org/whl/pytorch_triton_rocm-3.1.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
     # Torch
-    "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.4.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "torch @ https://download.pytorch.org/whl/rocm6.2/torch-2.5.1%2Brocm6.2-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "torch @ https://download.pytorch.org/whl/rocm6.2/torch-2.5.1%2Brocm6.2-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "torch @ https://download.pytorch.org/whl/rocm6.2/torch-2.5.1%2Brocm6.2-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+rocm6.1.torch2.4.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+rocm6.1.torch2.4.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.4/exllamav2-0.2.4+rocm6.1.torch2.4.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+rocm6.2.torch2.5.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+rocm6.2.torch2.5.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.7/exllamav2-0.2.7+rocm6.2.torch2.5.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]

 # MARK: Ruff options

View File

@@ -47,7 +47,7 @@ def get_install_features(lib_name: str = None):
     # Ask the user for the GPU lib
     gpu_lib_choices = {
         "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
-        "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
+        "B": {"pretty": "NVIDIA Cuda 11.8 (Unsupported)", "internal": "cu118"},
         "C": {"pretty": "AMD", "internal": "amd"},
     }
     user_input = get_user_choice(

View File

@@ -6,7 +6,7 @@
 {%- set start_header = "<|start_header_id|>" -%}
 {%- set end_header = "<|end_header_id|>\n" -%}

-{%- set example_tool_call -%}[
+{%- set example_tool_call = '[
     {
         "id": "tool_id_1342",
         "function": {
@@ -23,29 +23,26 @@
         },
         "type": "function"
     }
-]
-{%- endset -%}
+]' -%}

-{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
-1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
+{%- set inital_system_prompt = 'You are an assistant that has access to the following set of tools, to call a tool:
+1. Prefix calls with ' + tool_start + ' and end calls with ' + tool_end + '
 2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
 3. Generate all calls using the following json tool call format. Here is a multi tool call example:
-{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
+' + tool_start + example_tool_call + tool_end + '
 Here are the tools available for you to call:
-{{ tools_json }}
-{%- endset -%}
+' + tools_json -%}

-{%- set tool_reminder -%}Available Tools:
-{{ tools_json }}
+{%- set tool_reminder = 'Available Tools:
+' + tools_json + '
 Tool Call Format Example:
-{{ tool_start }}{{ example_tool_call }}
-Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
-Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
-{%- endset -%}
+' + tool_start + example_tool_call + '
+Prefix & Suffix: Begin tool calls with ' + tool_start + ' and end with ' + tool_end + '.
+Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).' -%}

 {# Template #}
@@ -54,15 +51,15 @@ Argument Types: Use correct data types for arguments (e.g., strings in quotes, n
     {%- if role not in message_roles -%}
         {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
     {%- endif -%}
-    {%- set content = message['content'] | default('', true) | trim -%}
+    {%- set content = message['content'] if message['content'] is defined else '' | trim -%}

     {%- if loop.first -%}
         {{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
         {{ inital_system_prompt }}
         {{ content }}{{ eos_token }}
     {%- endif -%}

     {%- if not loop.first -%}
         {{ start_header }}{{ role }}{{ end_header }}
         {{ content }}
@@ -81,4 +78,4 @@ Argument Types: Use correct data types for arguments (e.g., strings in quotes, n
         {{ tool_precursor }}{{ tool_start }}
     {%- else -%}
         {{ start_header }}assistant{{ end_header }}
     {%- endif -%}
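
Note on the content fallback rewrite above: in Jinja, filters bind tighter than the inline if/else, so in message['content'] if message['content'] is defined else '' | trim, the | trim applies only to the '' fallback, not to defined content. A minimal sketch demonstrating the binding (throwaway template; jinja2 assumed installed):

from jinja2 import Template

t = Template("{{ c if c is defined else '' | trim }}")
assert t.render(c="  hi  ") == "  hi  "  # defined content passes through untrimmed
assert t.render() == ""                  # only the fallback branch is trimmed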