Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-14 15:57:27 +00:00
feat: workflows for formatting/linting (#35)
* add github workflows for pylint and yapf
* yapf
* docstrings for auth
* fix auth.py
* fix generators.py
* fix gen_logging.py
* fix main.py
* fix model.py
* fix templating.py
* fix utils.py
* update formatting.sh to include subdirs for pylint
* fix model_test.py
* fix wheel_test.py
* rename utils to utils_oai
* fix OAI/utils_oai.py
* fix completion.py
* fix token.py
* fix lora.py
* fix common.py
* add pylintrc and fix model.py
* finish up pylint
* fix attribute error
* main.py formatting
* add formatting batch script
* Main: Remove unnecessary global. Linter suggestion. Signed-off-by: kingbri <bdashore3@proton.me>
* switch to ruff
* Formatting + Linting: Add ruff.toml. Signed-off-by: kingbri <bdashore3@proton.me>
* Formatting + Linting: Switch scripts to use ruff. Also remove the file and recent file change functions from both scripts. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Format and lint. Signed-off-by: kingbri <bdashore3@proton.me>
* Scripts + Workflows: Format. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Remove pylint flags. We use ruff now. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Format. Signed-off-by: kingbri <bdashore3@proton.me>
* Formatting: Line length is 88. Use the same value as Black. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Format. Update to new line length rules. Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri <bdashore3@proton.me>
model.py (399 changed lines)
@@ -1,29 +1,36 @@
"""The model container class for ExLlamaV2 models."""
import gc
import pathlib
import time

import torch
from exllamav2 import(
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
ExLlamaV2Lora
)
from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler
ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file
from templating import (
PromptTemplate,
find_template_from_model,
get_template_from_model_json,
get_template_from_file,
)
from utils import coalesce, unwrap

# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2


class ModelContainer:
"""The model container class for ExLlamaV2 models."""

config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
model: Optional[ExLlamaV2] = None
@@ -40,35 +47,51 @@ class ModelContainer:

active_loras: List[ExLlamaV2Lora] = []

def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container

Args:
model_dir (int): Model directory containing config.json, tokenizer.model etc.
model_dir (int): Model directory containing config.json,
tokenizer.model etc.
quiet (bool): Suppress console output
load_progress_callback (function, optional): A function to call for each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
load_progress_callback (function, optional): A function to call for
each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int,
loading_draft: bool)
**kwargs:
`cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16")
'max_seq_len' (int): Override model's default max sequence length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
'prompt_template' (str): Manually sets the prompt template for this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
batches. This limits the size of temporary buffers needed for the hidden state and attention
weights.
`cache_mode` (str): Sets cache mode, "FP16" or "FP8"
(defaulf: "FP16")
'max_seq_len' (int): Override model's default max sequence
length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model
(default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
(default: 1.0)
'prompt_template' (str): Manually sets the prompt template for
this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model
(default: 2048)
Inferencing in chunks reduces overall VRAM overhead by
processing very long sequences in smaller batches. This
limits the size of temporary buffers needed for the hidden
state and attention weights.
'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
By default, the draft model's alpha value is calculated automatically to scale to the size of the
'draft_rope_scale' (float): Set RoPE scaling factor for draft
model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
model. By default, the draft model's alpha value is
calculated automatically to scale to the size of the
full model.
'lora_dir' (str): Lora directory
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
'lora_dir' (str): LoRA directory
'loras' (list[dict]): List of loras to be loaded, consisting of
'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across
available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
"""

self.quiet = quiet
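As a rough illustration of the constructor options documented in the hunk above, a hypothetical caller might build the container like this (the directory path and option values are made up for the example, not taken from this commit):

# Hypothetical usage sketch; path and values are illustrative only.
model_dir = pathlib.Path("models/my-model")  # assumed local model directory
container = ModelContainer(
    model_dir,
    quiet=False,
    cache_mode="FP16",     # or "FP8"
    max_seq_len=4096,
    rope_scale=1.0,
    rope_alpha=1.0,
    chunk_size=2048,
    gpu_split_auto=True,
    no_flash_attn=False,
)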
@@ -90,7 +113,8 @@ class ModelContainer:
if override_base_seq_len:
self.config.max_seq_len = override_base_seq_len

# Grab the base model's sequence length before overrides for rope calculations
# Grab the base model's sequence length before overrides for
# rope calculations
base_seq_len = self.config.max_seq_len

# Set the target seq len if present
@@ -103,14 +127,14 @@ class ModelContainer:

# Automatically calculate rope alpha
self.config.scale_alpha_value = unwrap(
kwargs.get("rope_alpha"),
self.calculate_rope_alpha(base_seq_len)
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)

# Turn off flash attention?
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)

# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
# low_mem is currently broken in exllamav2. Don't use it until it's
# fixed.
"""
if "low_mem" in kwargs and kwargs["low_mem"]:
self.config.set_low_mem()
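The calls above lean on unwrap and coalesce from utils. Their implementation is not part of this diff; inferred from how they are used in this file, they presumably behave roughly like the sketch below (assumption, not the actual utils.py):

# Assumed behaviour, inferred from call sites; not the project's real helpers.
def unwrap(value, default):
    """Return value unless it is None, otherwise the default."""
    return value if value is not None else default

def coalesce(*args):
    """Return the first argument that is not None."""
    return next((arg for arg in args if arg is not None), None)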
@@ -119,7 +143,10 @@ class ModelContainer:
# Set prompt template override if provided
prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
print(f"Attempting to load prompt template with name {prompt_template_name}")
print(
"Attempting to load prompt template with name",
{prompt_template_name},
)
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
@@ -127,16 +154,17 @@ class ModelContainer:
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
"from_tokenizer_config"
"from_tokenizer_config",
)

# Try finding the chat template from the model's config.json
# TODO: This may not even be used with huggingface models, mark for removal.
# TODO: This may not even be used with huggingface models,
# mark for removal.
if self.prompt_template is None:
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_config),
"chat_template",
"from_model_config"
"from_model_config",
)

# If that fails, attempt fetching from model name
@@ -147,10 +175,13 @@ class ModelContainer:

# Catch all for template lookup errors
if self.prompt_template:
print(f"Using template {self.prompt_template.name} for chat completions.")
print(
f"Using template {self.prompt_template.name} for chat " "completions."
)
else:
print(
"Chat completions are disabled because a prompt template wasn't provided or auto-detected."
"Chat completions are disabled because a prompt template",
"wasn't provided or auto-detected.",
)

# Set num of experts per token if provided
@@ -159,11 +190,16 @@ class ModelContainer:
if hasattr(self.config, "num_experts_per_token"):
self.config.num_experts_per_token = num_experts_override
else:
print(" !! Warning: Currently installed ExLlamaV2 does not support overriding MoE experts")
print(
" !! Warning: Currently installed ExLlamaV2 does not "
"support overriding MoE experts"
)

chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
chunk_size = min(
unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
)
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
self.config.max_attn_size = chunk_size**2

draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -171,47 +207,63 @@ class ModelContainer:

# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
print(
"A draft config was found but a model name was not given. "
"Please check your config.yml! Skipping draft load."
)
enable_draft = False

if enable_draft:
self.draft_config = ExLlamaV2Config()
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name

self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()

self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
)

# Automatically calculate draft rope alpha
self.draft_config.scale_alpha_value = unwrap(
draft_args.get("draft_rope_alpha"),
self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.calculate_rope_alpha(self.draft_config.max_seq_len),
)
self.draft_config.max_seq_len = self.config.max_seq_len
self.draft_config.max_seq_len = self.config.max_seq_len

if "chunk_size" in kwargs:
self.draft_config.max_input_len = kwargs["chunk_size"]
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2

def calculate_rope_alpha(self, base_seq_len):
"""Calculate the rope alpha value for a given sequence length."""
ratio = self.config.max_seq_len / base_seq_len

# Default to a 1 alpha if the sequence length is ever less than or equal to 1
alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
# Default to a 1 alpha if the sequence length is ever less
# than or equal to 1
if ratio <= 1.0:
alpha = 1
else:
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
return alpha

def get_model_path(self, is_draft: bool = False):
model_path = pathlib.Path(self.draft_config.model_dir if is_draft else self.config.model_dir)
"""Get the path for this model."""
model_path = pathlib.Path(
self.draft_config.model_dir if is_draft else self.config.model_dir
)
return model_path

def load(self, progress_callback = None):
def load(self, progress_callback=None):
"""
Load model

Args:
progress_callback (function, optional): A function to call for each module loaded. Prototype:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int)
"""
for _ in self.load_gen(progress_callback):
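To make the alpha formula in the hunk above concrete: with a target context of twice the base sequence length (ratio = 2), the quadratic works out to roughly 2.63. A quick check (values illustrative):

# Worked example of the rope alpha formula shown above.
ratio = 8192 / 4096                      # target max_seq_len over base_seq_len
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
print(round(alpha, 2))                   # ~2.63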
@@ -231,25 +283,32 @@ class ModelContainer:
lora_scaling = unwrap(lora.get("scaling"), 1.0)

if lora_name is None:
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
print(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue

print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
# FIXME(alpin): Does self.model need to be passed here?
self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
print("Lora successfully loaded.")
success.append(lora_name)

# Return success and failure names
return { 'success': success, 'failure': failure }
return {"success": success, "failure": failure}

def load_gen(self, progress_callback = None):
def load_gen(self, progress_callback=None):
"""
Load model, generator function

Args:
progress_callback (function, optional): A function to call for each module loaded. Prototype:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int)
"""
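The progress callback described in the docstring above takes the loaded and total module counts. A minimal sketch of driving load() with such a callback, assuming a container built as in the earlier example:

# Hypothetical load call with a simple progress callback.
def progress(loaded_modules, total_modules):
    print(f"Loaded {loaded_modules}/{total_modules} modules")

container.load(progress)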
@@ -262,13 +321,18 @@ class ModelContainer:
if not self.quiet:
print("Loading draft model: " + self.draft_config.model_dir)

self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy = True)
reserve = [auto_split_reserve_bytes] + [0] * 16
yield from self.draft_model.load_autosplit_gen(self.draft_cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback)
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
yield from self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=reserve,
last_id_only=True,
callback_gen=progress_callback,
)

# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
self.draft_model.forward(input_ids, cache = self.cache, preprocess_only = True)
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

# Load model
self.model = ExLlamaV2(self.config)
@@ -276,29 +340,41 @@ class ModelContainer:
print("Loading model: " + self.config.model_dir)

if not self.gpu_split_auto:
for value in self.model.load_gen(self.gpu_split, callback_gen = progress_callback):
for value in self.model.load_gen(
self.gpu_split, callback_gen=progress_callback
):
if isinstance(value, str):
yield value

if self.cache_fp8:
self.cache = ExLlamaV2Cache_8bit(self.model, lazy = self.gpu_split_auto)
self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
else:
self.cache = ExLlamaV2Cache(self.model, lazy = self.gpu_split_auto)
self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)

if self.gpu_split_auto:
reserve = [auto_split_reserve_bytes] + [0] * 16
yield from self.model.load_autosplit_gen(self.cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback)
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
yield from self.model.load_autosplit_gen(
self.cache,
reserve_vram=reserve,
last_id_only=True,
callback_gen=progress_callback,
)

# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
self.model.forward(input_ids, cache = self.cache, preprocess_only = True)
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

# Create generator
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer, self.draft_model, self.draft_cache)
self.generator = ExLlamaV2StreamingGenerator(
self.model,
self.cache,
self.tokenizer,
self.draft_model,
self.draft_cache,
)

print("Model successfully loaded.")


def unload(self, loras_only: bool = False):
"""
Free all VRAM resources used by this model
@@ -327,19 +403,24 @@ class ModelContainer:
gc.collect()
torch.cuda.empty_cache()

# Common function for token operations
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
"""Common function for token operations"""
if text:
# Assume token encoding
return self.tokenizer.encode(
text,
add_bos = unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
)
if ids:
# Assume token decoding
ids = torch.tensor([ids])
return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
return self.tokenizer.decode(
ids,
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]

return None

def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
return {
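get_tokens picks the encode path when text is given and the decode path when a list of ids is given. A hypothetical round trip, assuming a loaded container (values illustrative):

# Hypothetical round trip; assumes `container` is a loaded ModelContainer.
ids = container.get_tokens("Hello, world!", None, add_bos_token=True)            # encode path
text = container.get_tokens(None, ids[0].tolist(), decode_special_tokens=True)   # decode path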
@@ -350,13 +431,15 @@ class ModelContainer:
}

def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
if generation:
response = "".join(map(lambda chunk: chunk[0], generation))
return response, generation[-1][1], generation[-1][2]
else:
return "", 0, 0

return "", 0, 0

# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def generate_gen(self, prompt: str, **kwargs):
"""
Create generator function for prompt completion
@@ -366,7 +449,8 @@ class ModelContainer:
**kwargs:
'token_healing' (bool): Use token healing (default: False)
'temperature' (float): Sampling temperature (default: 1.0)
'temperature_last' (bool): Apply temperature after all other samplers (default: False)
'temperature_last' (bool): Apply temperature after all other
samplers (default: False)
'top_k' (int): Sampling top-K (default: 0)
'top_p' (float): Sampling top-P (default: 1.0)
'min_p' (float): Sampling min-P (default: 0.0)
@@ -375,19 +459,27 @@ class ModelContainer:
'mirostat' (bool): Use Mirostat (default: False)
'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'repetition_range' (int): Repetition penalty range (default: whole context)
'repetition_decay' (int): Repetition penalty range (default: same as range)
'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
'repetition_penalty' (float): Token repetition/presence penalty
(default: 1.15)
'repetition_range' (int): Repetition penalty range
(default: whole context)
'repetition_decay' (int): Repetition penalty range
(default: same as range)
'stop' (List[Union[str, int]]): List of stop strings/tokens to
end response (default: [EOS])
'max_tokens' (int): Max no. tokens in response (default: 150)
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None)
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the model's context when generating.
Rolls context window by the same amount if context length is exceeded to allow generating past
the models max_seq_len.

'add_bos_token' (bool): Adds the BOS token to the start of the
prompt (default: True)
'ban_eos_token' (bool): Bans the EOS token from generation
(default: False)
'logit_bias' (Dict[int, float]): Biases specific tokens to
either show up more or less (default: None)
'stream_interval' (float): Interval in seconds between each
output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the
model's context when generating. Rolls context window by
the same amount if context length is exceeded to allow
generating pastthe models max_seq_len.
"""

token_healing = unwrap(kwargs.get("token_healing"), False)
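As a hedged sketch of how the sampling options documented above might be passed in (the prompt and parameter values are illustrative, not defaults taken from this commit):

# Illustrative call; assumes `container` is a loaded ModelContainer.
response, prompt_tokens, generated_tokens = container.generate(
    "Write a haiku about GPUs.",
    max_tokens=150,
    temperature=0.8,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.15,
    stop=["\n\n"],
)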
@@ -399,17 +491,37 @@ class ModelContainer:
gen_settings = ExLlamaV2Sampler.Settings()

# Warn of unsupported settings if the setting is enabled
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
gen_settings, "mirostat"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"Mirostat sampling"
)

if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
gen_settings, "min_p"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not "
"support min-P sampling"
)

if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
gen_settings, "tfs"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"tail-free sampling (TFS)"
)

if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
gen_settings, "temperature_last"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"temperature_last"
)

# Apply settings
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@@ -424,14 +536,24 @@ class ModelContainer:
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
gen_settings.token_repetition_range = unwrap(
kwargs.get("repetition_range"), self.config.max_seq_len
)

# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed fallback
# It's technically fine to use -1, but this just validates the passed
# fallback
# Always default to 0 if something goes wrong
fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
if gen_settings.token_repetition_range <= 0:
fallback_decay = 0
else:
fallback_decay = gen_settings.token_repetition_range
gen_settings.token_repetition_decay = coalesce(
kwargs.get("repetition_decay"), fallback_decay, 0
)

stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
@@ -448,13 +570,13 @@ class ModelContainer:
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
max_tokens = max_tokens,
max_tokens=max_tokens,
**vars(gen_settings),
token_healing = token_healing,
add_bos_token = add_bos_token,
ban_eos_token = ban_eos_token,
stop_conditions = stop_conditions,
logit_bias = logit_bias
token_healing=token_healing,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
stop_conditions=stop_conditions,
logit_bias=logit_bias,
)

# Log prompt to console
@@ -465,13 +587,17 @@ class ModelContainer:
# Create a vocab tensor if it doesn't exist for token biasing
if gen_settings.token_bias is None:
padding = -self.tokenizer.config.vocab_size % 32
gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float)
gen_settings.token_bias = torch.zeros(
(self.tokenizer.config.vocab_size + padding,),
dtype=torch.float,
)

# Map logits to the tensor with their biases
for token, bias in logit_bias.items():
gen_settings.token_bias[token] = bias

# Ban the EOS token if specified. If not, append to stop conditions as well.
# Ban the EOS token if specified. If not, append to stop conditions
# as well.
# Set this below logging to avoid polluting the stop strings array
if ban_eos_token:
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
@@ -483,16 +609,15 @@ class ModelContainer:

# Tokenized context
ids = self.tokenizer.encode(
prompt,
add_bos = add_bos_token,
encode_special_tokens = True
prompt, add_bos=add_bos_token, encode_special_tokens=True
)
context_len = len(ids[0])

if context_len > self.config.max_seq_len:
print(
f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.",
"Generation is truncated and metrics may not be accurate."
f"WARNING: The context length {context_len} is greater than "
f"the max_seq_len {self.config.max_seq_len}.",
"Generation is truncated and metrics may not be accurate.",
)

prompt_tokens = ids.shape[-1]
@@ -503,26 +628,32 @@ class ModelContainer:
start_time = time.time()
last_chunk_time = start_time

save_tokens = torch.empty((1, 0), dtype = torch.bool)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
chunk_buffer = ""
chunk_tokens = 0

while True:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim = - 1)
save_tokens = torch.empty((1, 0), dtype = torch.bool)
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow):]
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)

# Generate
chunk, eos, tokens = self.generator.stream()

if token_healing:
ids[:, -1] = self.generator.sequence_ids[:, -2] # Extract healed token
# Extract healed token
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False

save_tokens = torch.cat((save_tokens, tokens), dim=-1)
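For streaming consumers, the generator form shown above yields (chunk, prompt_tokens, generated_tokens) tuples. It could be driven like this hypothetical snippet (prompt and options illustrative):

# Hypothetical streaming consumption of generate_gen.
for chunk, prompt_tokens, generated_tokens in container.generate_gen(
    "Tell me a story.", max_tokens=200, stream_interval=0.25
):
    print(chunk, end="", flush=True)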
@@ -535,7 +666,9 @@ class ModelContainer:
now = time.time()
elapsed = now - last_chunk_time

if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
if chunk_buffer != "" and (
elapsed > stream_interval or eos or generated_tokens == max_tokens
):
yield chunk_buffer, prompt_tokens, generated_tokens
full_response += chunk_buffer
chunk_buffer = ""
@@ -549,12 +682,20 @@ class ModelContainer:

elapsed_time = last_chunk_time - start_time

initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
initial_response = (
f"Metrics: {generated_tokens} tokens generated in "
f"{round(elapsed_time, 2)} seconds"
)
itemization = []
extra_parts = []

# Add tokens per second
itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
tokens_per_second = (
"Indeterminate"
if elapsed_time == 0
else round(generated_tokens / elapsed_time, 2)
)
itemization.append(f"{tokens_per_second} T/s")

# Add context (original token count)
if ids is not None:
@@ -564,4 +705,10 @@ class ModelContainer:
extra_parts.append("<-- Not accurate (truncated)")

# Print output
print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts))
print(
initial_response
+ " ("
+ ", ".join(itemization)
+ ") "
+ " ".join(extra_parts)
)