Model: Remove and format comments

The docstring in __init__ was outdated, and all the kwargs are the
config options anyway.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-08-23 21:33:18 -04:00
parent 80198ca056
commit 4958c06813

View File

@@ -101,51 +101,9 @@ class ExllamaV2Container:
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
Primary initializer for model container.
Args:
model_dir (int): Model directory containing config.json,
tokenizer.model etc.
quiet (bool): Suppress console output
load_progress_callback (function, optional): A function to call for
each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int,
loading_draft: bool)
**kwargs:
`cache_mode` (str): Sets cache mode: "FP16"/"Q8"/"Q6"/"Q4"
(default: "FP16")
'max_seq_len' (int): Override model's default max sequence
length (default: 4096)
'cache_size' (int): Num of tokens to allocate space for in the k/v cache
(default: max_seq_len)
'rope_scale' (float): Set RoPE scaling factor for model
(default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
(default: 1.0)
'prompt_template' (str): Manually sets the prompt template for
this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model
(default: 2048)
Inferencing in chunks reduces overall VRAM overhead by
processing very long sequences in smaller batches. This
limits the size of temporary buffers needed for the hidden
state and attention weights.
'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft
model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
model. By default, the draft model's alpha value is
calculated automatically to scale to the size of the
full model.
'draft_cache_mode' (str): Sets draft cache mode: "FP16"/"Q8"/"Q6"/"Q4"
(default: "FP16")
'lora_dir' (str): LoRA directory
'loras' (list[dict]): List of loras to be loaded, consisting of
'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across
available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
Kwargs are located in config_sample.yml
"""
self.quiet = quiet
@@ -386,7 +344,7 @@ class ExllamaV2Container:
self.draft_config.max_attention_size = chunk_size**2
def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods"""
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
@@ -428,6 +386,7 @@ class ExllamaV2Container:
def calculate_rope_alpha(self, base_seq_len):
"""Calculate the rope alpha value for a given sequence length."""
ratio = self.config.max_seq_len / base_seq_len
# Default to a 1 alpha if the sequence length is ever less
@@ -504,7 +463,9 @@ class ExllamaV2Container:
Args:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
module loaded.
Prototype:
def progress(loaded_modules: int, total_modules: int)
"""
@@ -549,11 +510,13 @@ class ExllamaV2Container:
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
"""
Load model, generator function
Synchronous generator for loading.
Args:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
module loaded.
Prototype:
def progress(loaded_modules: int, total_modules: int)
Runs under a shared inference mode context.
@@ -695,6 +658,8 @@ class ExllamaV2Container:
)
async def create_generator(self):
"""Create and save a Exllama generator class."""
try:
# Don't acquire locks unless a model is loaded
if self.model_loaded:
@@ -728,9 +693,7 @@ class ExllamaV2Container:
return unwrap(self.generator.generator.current_loras, [])
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""
"""Load loras."""
loras = unwrap(kwargs.get("loras"), [])
@@ -777,9 +740,7 @@ class ExllamaV2Container:
self.load_condition.notify_all()
async def unload(self, loras_only: bool = False, **kwargs):
"""
Free all VRAM resources used by this model
"""
"""Free all VRAM resources used by the model (and loras)."""
# Shutdown immediately unloads and bypasses all locks
do_shutdown = kwargs.get("shutdown")
@@ -836,7 +797,7 @@ class ExllamaV2Container:
self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
"""Wrapper to encode tokens from a text string."""
return (
self.tokenizer.encode(
@@ -888,7 +849,7 @@ class ExllamaV2Container:
async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
):
"""Generate a response to a prompt"""
"""Generate a response to a prompt."""
generations = []
async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
@@ -939,7 +900,11 @@ class ExllamaV2Container:
return joined_generation
def check_unsupported_settings(self, **kwargs):
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
"""
Check and warn the user if a sampler is unsupported.
Meant for dev wheels!
"""
return kwargs