Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-03-14 15:57:27 +00:00)
Model: Remove and format comments
The comment in __init__ was outdated and all the kwargs are the config options anyways.

Signed-off-by: kingbri <bdashore3@proton.me>
@@ -101,51 +101,9 @@ class ExllamaV2Container:
     def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
         """
-        Create model container
+        Primary initializer for model container.

-        Args:
-            model_dir (int): Model directory containing config.json,
-                tokenizer.model etc.
-            quiet (bool): Suppress console output
-            load_progress_callback (function, optional): A function to call for
-                each module loaded. Prototype:
-                def progress(loaded_modules: int, total_modules: int,
-                             loading_draft: bool)
-            **kwargs:
-                `cache_mode` (str): Sets cache mode: "FP16"/"Q8"/"Q6"/"Q4"
-                    (default: "FP16")
-                'max_seq_len' (int): Override model's default max sequence
-                    length (default: 4096)
-                'cache_size' (int): Num of tokens to allocate space for in the k/v cache
-                    (default: max_seq_len)
-                'rope_scale' (float): Set RoPE scaling factor for model
-                    (default: 1.0)
-                'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
-                    (default: 1.0)
-                'prompt_template' (str): Manually sets the prompt template for
-                    this model (default: None)
-                'chunk_size' (int): Sets the maximum chunk size for the model
-                    (default: 2048)
-                    Inferencing in chunks reduces overall VRAM overhead by
-                    processing very long sequences in smaller batches. This
-                    limits the size of temporary buffers needed for the hidden
-                    state and attention weights.
-                'draft_model_dir' (str): Draft model directory
-                'draft_rope_scale' (float): Set RoPE scaling factor for draft
-                    model (default: 1.0)
-                'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
-                    model. By default, the draft model's alpha value is
-                    calculated automatically to scale to the size of the
-                    full model.
-                'draft_cache_mode' (str): Sets draft cache mode: "FP16"/"Q8"/"Q6"/"Q4"
-                    (default: "FP16")
-                'lora_dir' (str): LoRA directory
-                'loras' (list[dict]): List of loras to be loaded, consisting of
-                    'name' and 'scaling'
-                'gpu_split_auto' (bool): Automatically split model across
-                    available devices (default: True)
-                'gpu_split' (list[float]): Allocation for weights and (some)
-                    tensors, per device

         Kwargs are located in config_sample.yml
         """

         self.quiet = quiet
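The kwargs the removed docstring listed still map one-to-one onto the options in config_sample.yml, so call sites are unchanged; a minimal usage sketch, where the import path and every value below are hypothetical:

    import pathlib

    # Hypothetical import path for the container class
    from backends.exllamav2.model import ExllamaV2Container

    # Each keyword mirrors an option documented in config_sample.yml
    container = ExllamaV2Container(
        pathlib.Path("models/MyModel-exl2"),  # hypothetical model directory
        quiet=False,
        cache_mode="Q4",      # one of "FP16"/"Q8"/"Q6"/"Q4"
        max_seq_len=8192,     # override the model's native context length
        rope_alpha=2.0,       # NTK alpha factor
        gpu_split_auto=True,  # split weights across available GPUs
    )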
@@ -386,7 +344,7 @@ class ExllamaV2Container:
             self.draft_config.max_attention_size = chunk_size**2

     def find_prompt_template(self, prompt_template_name, model_directory):
-        """Tries to find a prompt template using various methods"""
+        """Tries to find a prompt template using various methods."""

         logger.info("Attempting to load a prompt template if present.")
@@ -428,6 +386,7 @@ class ExllamaV2Container:

     def calculate_rope_alpha(self, base_seq_len):
         """Calculate the rope alpha value for a given sequence length."""
+
         ratio = self.config.max_seq_len / base_seq_len

         # Default to a 1 alpha if the sequence length is ever less
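The hunk cuts off before the formula itself, but the visible lines fix its shape: take the ratio of the configured context length to the model's native one, and fall back to an alpha of 1 when no extension is needed. A standalone sketch using the widely circulated quadratic NTK-alpha fit; the coefficients are an assumption, not read from this file:

    def calculate_rope_alpha(max_seq_len: int, base_seq_len: int) -> float:
        """Sketch of an NTK-aware rope alpha calculation."""
        ratio = max_seq_len / base_seq_len

        # Default to a 1 alpha if the sequence length is ever less than
        # (or equal to) the model's native context: no scaling needed
        if ratio <= 1:
            return 1.0

        # Quadratic fit from extension ratio to alpha (assumed coefficients)
        return -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2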
@@ -504,7 +463,9 @@ class ExllamaV2Container:

         Args:
             progress_callback (function, optional): A function to call for each
-                module loaded. Prototype:
+                module loaded.
+
+            Prototype:
             def progress(loaded_modules: int, total_modules: int)
         """
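A callback matching this prototype needs nothing beyond the two counters; a minimal example, with an illustrative output format:

    def progress(loaded_modules: int, total_modules: int):
        # Report module-level loading progress
        percent = loaded_modules / total_modules * 100
        print(f"Loaded {loaded_modules}/{total_modules} modules ({percent:.0f}%)")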
@@ -549,11 +510,13 @@ class ExllamaV2Container:
     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
         """
-        Load model, generator function
+        Synchronous generator for loading.

         Args:
             progress_callback (function, optional): A function to call for each
-                module loaded. Prototype:
+                module loaded.
+
+            Prototype:
             def progress(loaded_modules: int, total_modules: int)

         Runs under a shared inference mode context.
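Since load_model_sync is a plain generator, a caller has to drive it to completion; a hypothetical consumption loop, reusing the progress callback sketched above and the container from the first sketch:

    # Hypothetical: exhaust the loading generator, yielding between modules
    for _ in container.load_model_sync(progress_callback=progress):
        pass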
@@ -695,6 +658,8 @@ class ExllamaV2Container:
         )

     async def create_generator(self):
+        """Create and save a Exllama generator class."""
+
         try:
             # Don't acquire locks unless a model is loaded
             if self.model_loaded:
@@ -728,9 +693,7 @@ class ExllamaV2Container:
         return unwrap(self.generator.generator.current_loras, [])

     async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
-        """
-        Load loras
-        """
+        """Load loras."""

         loras = unwrap(kwargs.get("loras"), [])
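Per the docstring removed in the first hunk, the loras kwarg is a list of dicts carrying 'name' and 'scaling'; a hypothetical call, with illustrative names and scalings:

    import asyncio
    import pathlib

    async def main():
        await container.load_loras(
            pathlib.Path("loras"),  # hypothetical lora directory
            loras=[
                {"name": "style-lora", "scaling": 1.0},
                {"name": "task-lora", "scaling": 0.5},
            ],
        )

    asyncio.run(main())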
@@ -777,9 +740,7 @@ class ExllamaV2Container:
         self.load_condition.notify_all()

     async def unload(self, loras_only: bool = False, **kwargs):
-        """
-        Free all VRAM resources used by this model
-        """
+        """Free all VRAM resources used by the model (and loras)."""

         # Shutdown immediately unloads and bypasses all locks
         do_shutdown = kwargs.get("shutdown")
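The loras_only flag lets this double as a lora-only reset; hypothetical calls, from inside a coroutine:

    await container.unload(loras_only=True)  # free lora weights only
    await container.unload()                 # free the whole model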
@@ -836,7 +797,7 @@ class ExllamaV2Container:
         self.load_condition.notify_all()

     def encode_tokens(self, text: str, **kwargs):
-        """Wrapper to encode tokens from a text string"""
+        """Wrapper to encode tokens from a text string."""

         return (
             self.tokenizer.encode(
@@ -888,7 +849,7 @@ class ExllamaV2Container:
     async def generate(
         self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
     ):
-        """Generate a response to a prompt"""
+        """Generate a response to a prompt."""
         generations = []
         async for generation in self.generate_gen(
             prompt, request_id, abort_event, **kwargs
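The wrapper in this hunk collects the streamed chunks from generate_gen and returns one joined result; a hypothetical end-to-end call, where the prompt and request id are illustrative and extra kwargs pass through as sampler settings:

    import asyncio

    async def main():
        result = await container.generate(
            "Write a haiku about speculative decoding.",
            request_id="req-1",
        )
        print(result)

    asyncio.run(main())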
@@ -939,7 +900,11 @@ class ExllamaV2Container:
         return joined_generation

     def check_unsupported_settings(self, **kwargs):
-        """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
+        """
+        Check and warn the user if a sampler is unsupported.
+
+        Meant for dev wheels!
+        """

         return kwargs