Model: Add CFG support
CFG, or classifier-free guidance, helps push a model in different directions based on what the user provides. Currently, CFG is ignored if the negative prompt is blank (it shouldn't be used in that way anyway).

Signed-off-by: kingbri <bdashore3@proton.me>
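For context, classifier-free guidance decodes the positive and negative prompts side by side and blends their logits at every step. The sketch below is illustrative only; the real mixing happens inside ExLlamaV2's sampler once cfg_scale is set, and the function and tensor names here are assumptions.

import torch
import torch.nn.functional as F

def apply_cfg(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Blend logits from the positive (cond) and negative (uncond) prompt rows.

    With cfg_scale == 1.0 this is a no-op; values above 1.0 push the output
    toward the positive prompt and away from the negative one.
    """
    cond = F.log_softmax(cond_logits, dim=-1)
    uncond = F.log_softmax(uncond_logits, dim=-1)
    return uncond + cfg_scale * (cond - uncond)

# Toy example: a fake vocab of 5 tokens, one decode step per prompt row
cond = torch.randn(1, 5)
uncond = torch.randn(1, 5)
mixed = apply_cfg(cond, uncond, cfg_scale=1.5)
next_token = mixed.argmax(dim=-1)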
@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
     add_bos_token: Optional[bool] = True
     ban_eos_token: Optional[bool] = False
     logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
+    negative_prompt: Optional[str] = None
 
     # Aliased variables
     penalty_range: Optional[int] = Field(
@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
         ),
     )
 
+    cfg_scale: Optional[float] = Field(
+        default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
+    )
+
     def to_gen_params(self):
         """Converts to internal generation parameters."""
         # Convert stop to an array of strings
@@ -115,4 +120,6 @@ class CommonCompletionRequest(BaseModel):
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
             "mirostat_eta": self.mirostat_eta,
+            "cfg_scale": self.cfg_scale,
+            "negative_prompt": self.negative_prompt,
         }
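With the two new request fields in place, a completion request that exercises them might look like the sketch below, assuming a local tabbyAPI instance on its default port with the OpenAI-style completions route; the port and auth header are assumptions for illustration.

import requests

payload = {
    "prompt": "Write a cheerful greeting.",
    "max_tokens": 64,
    "cfg_scale": 1.5,              # also accepted under the alias "guidance_scale"
    "negative_prompt": "Grumpy, sarcastic replies.",
}

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json=payload,
    headers={"x-api-key": "your-api-key"},  # assumed auth header
)
print(response.json())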
@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None
+    use_cfg: Optional[bool] = None
     draft: Optional[DraftModelLoadRequest] = None
 
 
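Because use_cfg is a load-time option (it changes the cache batch size and disables flash attention), it has to be passed when the model is loaded rather than per request. A hedged sketch of such a load request; the admin route, header name, and model name are assumptions:

import requests

load_payload = {
    "name": "MyModel-exl2",   # hypothetical model folder name
    "use_cfg": True,
}

requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    json=load_payload,
    headers={"x-admin-key": "your-admin-key"},  # assumed admin header
)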
args.py (5 changed lines)
@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
         type=int,
         help="Number of experts to use per token in MoE models",
     )
+    model_group.add_argument(
+        "--use-cfg",
+        type=str_to_bool,
+        help="Enables CFG support",
+    )
 
 
 def add_logging_args(parser: argparse.ArgumentParser):
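str_to_bool is an existing helper in the repo; the diff only references it. A minimal sketch of the behavior --use-cfg needs from it (not the repo's exact implementation):

import argparse

def str_to_bool(value: str) -> bool:
    """Map common true/false spellings from the CLI to a bool."""
    if value.lower() in ("true", "1", "yes", "y"):
        return True
    if value.lower() in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}")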
@@ -85,6 +85,10 @@ model:
   # NOTE: For MoE models (ex. Mixtral) only!
   #num_experts_per_token:
 
+  # Enables CFG support (default: False)
+  # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream)
+  use_cfg: False
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   #draft:
     # Overrides the directory to look for draft (default: models)
@@ -1,8 +1,8 @@
 """
 Functions for logging generation events.
 """
-from typing import Dict
 
 from pydantic import BaseModel
+from typing import Dict, Optional
 
 from logger import init_logger
@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")
 
 
-def log_prompt(prompt: str):
+def log_prompt(prompt: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
         logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
 
+    if negative_prompt:
+        formatted_negative_prompt = "\n" + negative_prompt
+        logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
+
 
 def log_response(response: str):
     """Logs the response to console."""
model.py (89 changed lines)
@@ -47,6 +47,7 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
     gpu_split: Optional[list] = None
+    use_cfg: bool = False
 
     active_loras: List[ExLlamaV2Lora] = []
 
@@ -95,6 +96,8 @@ class ModelContainer:
                 tensors, per device
             'no_flash_attn' (bool): Turns off flash attention
                 (increases vram usage) (default: False)
+            'use_cfg" (bool): Enables CFG support. Disables flash attention
+                (default: False)
         """
 
         self.quiet = quiet
@@ -135,8 +138,18 @@ class ModelContainer:
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )
 
-        # Turn off flash attention?
-        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
+        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
+            self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
+        else:
+            logger.warning(
+                "CFG is not supported by the currently installed ExLlamaV2 version."
+            )
+
+        # Turn off flash attention if CFG is on
+        # Workaround until batched FA2 is fixed in exllamav2 upstream
+        self.config.no_flash_attn = (
+            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        )
 
         # low_mem is currently broken in exllamav2. Don't use it until it's
         # fixed.
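Two small patterns carry this hunk: unwrap(value, default), tabbyAPI's None-coalescing helper, and hasattr feature detection so older ExLlamaV2 installs whose sampler settings lack cfg_scale degrade gracefully instead of crashing. A rough equivalent of the unwrap behavior assumed throughout the diff:

def unwrap(value, default=None):
    """Return value unless it is None, in which case fall back to default."""
    return value if value is not None else default

# Feature detection keeps the container working on older exllamav2 builds:
# if ExLlamaV2Sampler.Settings has no cfg_scale attribute, use_cfg stays False
# and a warning is logged instead.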
@@ -348,10 +361,15 @@ class ModelContainer:
            if isinstance(value, str):
                yield value
 
+        batch_size = 2 if self.use_cfg else 1
         if self.cache_fp8:
-            self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache_8bit(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
         else:
-            self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
 
         if self.gpu_split_auto:
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
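The cache is allocated with batch_size=2 when CFG is on because the positive and negative prompts are decoded side by side, so the KV cache roughly doubles. A back-of-the-envelope estimate with illustrative (assumed) 7B-class dimensions:

def kv_cache_bytes(batch_size, max_seq_len, num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """Approximate KV cache size: keys + values for every layer."""
    return 2 * batch_size * max_seq_len * num_layers * num_kv_heads * head_dim * bytes_per_elem

# Assumed dims: 4096 context, 32 layers, 8 KV heads, head_dim 128, FP16 cache
without_cfg = kv_cache_bytes(1, 4096, 32, 8, 128)
with_cfg = kv_cache_bytes(2, 4096, 32, 8, 128)
print(f"{without_cfg / 2**30:.2f} GiB without CFG vs {with_cfg / 2**30:.2f} GiB with CFG")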
@@ -561,6 +579,19 @@ class ModelContainer:
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
         gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
 
+        # Set CFG scale and negative prompt
+        cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
+        negative_prompt = None
+        if cfg_scale not in [None, 1.0]:
+            if self.use_cfg:
+                gen_settings.cfg_scale = cfg_scale
+                negative_prompt = kwargs.get("negative_prompt")
+            else:
+                logger.warn(
+                    "CFG is currently disabled. "
+                    + "Please reload your model with use_cfg = True.",
+                )
+
         gen_settings.token_presence_penalty = unwrap(
             kwargs.get("presence_penalty"), 0.0
         )
@@ -635,7 +666,7 @@ class ModelContainer:
         )
 
         # Log prompt to console
-        log_prompt(prompt)
+        log_prompt(prompt, negative_prompt)
 
         # Set logit bias
         if logit_bias:
@@ -663,8 +694,18 @@ class ModelContainer:
         self.generator.set_stop_conditions(stop_conditions)
 
         # Tokenized context
-        ids = self.tokenizer.encode(
-            prompt, add_bos=add_bos_token, encode_special_tokens=True
+        ids, offsets = self.tokenizer.encode(
+            [prompt, negative_prompt]
+            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
+            else prompt,
+            add_bos=add_bos_token,
+            encode_special_tokens=True,
+            return_offsets=True,
+        )
+        mask = (
+            self.tokenizer.padding_mask(ids)
+            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
+            else None
         )
         context_len = len(ids[0])
 
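Encoding [prompt, negative_prompt] as one batch only works if both rows end up the same length, which is why the tokenizer is asked for position offsets and a padding mask for the shorter row. The sketch below illustrates the idea with plain torch; the actual shapes, padding side, and mask values are handled inside ExLlamaV2's tokenizer (padding_mask / return_offsets), so treat these details as assumptions.

import torch

def pad_to_batch(rows, pad_id=0):
    """Left-pad token rows to equal length and mark which positions hold real tokens."""
    max_len = max(len(r) for r in rows)
    ids = torch.full((len(rows), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(rows), max_len), dtype=torch.bool)
    offsets = []
    for i, row in enumerate(rows):
        pad = max_len - len(row)
        ids[i, pad:] = torch.tensor(row, dtype=torch.long)
        mask[i, pad:] = True
        offsets.append(-pad)  # shift positions so the first real token sits at position 0
    return ids, mask, torch.tensor(offsets)

ids, mask, offsets = pad_to_batch([[1, 5, 9, 2], [1, 7]])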
@@ -683,7 +724,7 @@ class ModelContainer:
         start_time = time.time()
         last_chunk_time = start_time
 
-        save_tokens = torch.empty((1, 0), dtype=torch.bool)
+        save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
         chunk_buffer = ""
         chunk_tokens = 0
 
@@ -691,17 +732,31 @@ class ModelContainer:
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)
-                save_tokens = torch.empty((1, 0), dtype=torch.bool)
+                save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
                 overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
                 active_ids = ids[:, max(0, overflow) :]
                 chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
 
-                self.generator.begin_stream(
-                    active_ids,
-                    gen_settings,
-                    token_healing=token_healing,
-                    loras=self.active_loras,
-                )
+                # Split for exllama versions that have CFG
+                if self.use_cfg:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                        input_mask=mask,
+                        position_offsets=offsets,
+                    )
+                else:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                    )
+
+                # Reset offsets for subsequent passes if the context is truncated
+                offsets = None
 
             if auto_scale_penalty_range:
                 gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +769,9 @@ class ModelContainer:
                 ids[:, -1] = self.generator.sequence_ids[:, -2]
                 token_healing = False
 
-            save_tokens = torch.cat((save_tokens, tokens), dim=-1)
+            save_tokens = torch.cat(
+                (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
+            )
             chunk_buffer += chunk
 
             generated_tokens += 1
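The streaming generator emits one new token per step, but with CFG the context tensor now has two rows (positive and negative prompt), so the sampled token is broadcast to both rows before being appended. A minimal torch illustration of the expand call, with an assumed token id and a long dtype for readability:

import torch

save_tokens = torch.empty((2, 0), dtype=torch.long)  # two rows when CFG is active
token = torch.tensor([[42]])                          # one sampled token, shape (1, 1)

# Broadcast the single sampled token across both batch rows, then append it.
save_tokens = torch.cat((save_tokens, token.expand(save_tokens.shape[0], -1)), dim=-1)
print(save_tokens)  # tensor([[42], [42]])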