Model: Add CFG support

CFG, or classifier-free guidance, pushes a model toward the prompt and
away from a negative prompt the user provides.

Currently, CFG is ignored if the negative prompt is blank (it shouldn't
be used that way anyway).

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-01-02 01:09:26 -05:00
Committed by: Brian Dashore
Commit: b378773d0a (parent: bb7a8e4614)
6 changed files with 96 additions and 18 deletions
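
A note on what CFG actually computes: the model is run on the prompt and the negative prompt together, and their logits are mixed at every step. A minimal sketch of that mixing (the general formula, not exllamav2's implementation; all names here are illustrative):

import torch

def apply_cfg(cond_logits: torch.Tensor, uncond_logits: torch.Tensor,
              cfg_scale: float) -> torch.Tensor:
    # Extrapolate from the negative-prompt logits toward the
    # positive-prompt logits. A scale of exactly 1.0 returns the
    # positive logits unchanged, which is why 1.0 means "CFG off".
    return uncond_logits + cfg_scale * (cond_logits - uncond_logits)

cond = torch.randn(1, 32000)    # logits conditioned on the prompt
uncond = torch.randn(1, 32000)  # logits conditioned on the negative prompt
mixed = apply_cfg(cond, uncond, cfg_scale=1.5)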


@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
     add_bos_token: Optional[bool] = True
     ban_eos_token: Optional[bool] = False
     logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
+    negative_prompt: Optional[str] = None
 
     # Aliased variables
     penalty_range: Optional[int] = Field(
@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
         ),
     )
 
+    cfg_scale: Optional[float] = Field(
+        default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
+    )
+
     def to_gen_params(self):
         """Converts to internal generation parameters."""
         # Convert stop to an array of strings
@@ -115,4 +120,6 @@ class CommonCompletionRequest(BaseModel):
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
             "mirostat_eta": self.mirostat_eta,
+            "cfg_scale": self.cfg_scale,
+            "negative_prompt": self.negative_prompt,
         }
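
With these fields, a completion request can opt into CFG per call. A hypothetical request (the host, port, and endpoint path are assumptions, not taken from this commit):

import requests

payload = {
    "prompt": "Write a haiku about rain.",
    "max_tokens": 60,
    "cfg_scale": 1.5,  # "guidance_scale" also works, via the validation alias
    "negative_prompt": "Do not mention storms.",
}
response = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)
print(response.json())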


@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
+    use_cfg: Optional[bool] = None
     draft: Optional[DraftModelLoadRequest] = None
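
The load request gains a matching tri-state flag: an explicit true/false overrides the server config, while omitting the key leaves it at None. A standalone check of how the field parses (Pydantic v2, trimmed to just the new field):

from typing import Optional

from pydantic import BaseModel

class ModelLoadRequest(BaseModel):
    # Trimmed to the new field for illustration.
    use_cfg: Optional[bool] = None

print(ModelLoadRequest.model_validate({"use_cfg": True}).use_cfg)  # True
print(ModelLoadRequest.model_validate({}).use_cfg)                 # None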


@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
         type=int,
         help="Number of experts to use per token in MoE models",
     )
+    model_group.add_argument(
+        "--use-cfg",
+        type=str_to_bool,
+        help="Enables CFG support",
+    )
 
 
 def add_logging_args(parser: argparse.ArgumentParser):
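
The same flag is exposed on the command line. str_to_bool is an existing project helper that isn't shown in this diff; the stand-in below is an assumption, kept only so the snippet runs:

import argparse

def str_to_bool(value: str) -> bool:
    # Assumed stand-in for the project's helper, not its actual code.
    return value.strip().lower() in ("true", "1", "yes", "y", "on")

parser = argparse.ArgumentParser()
parser.add_argument("--use-cfg", type=str_to_bool, help="Enables CFG support")
print(parser.parse_args(["--use-cfg", "true"]).use_cfg)  # True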


@@ -85,6 +85,10 @@ model:
   # NOTE: For MoE models (ex. Mixtral) only!
   #num_experts_per_token:
 
+  # Enables CFG support (default: False)
+  # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream)
+  use_cfg: False
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   #draft:
     # Overrides the directory to look for draft (default: models)
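
Since the option lives under the model block of the sample config, reading it follows the same path as the other model keys. A minimal sketch (the file name and loader are assumptions based on the sample above):

import yaml  # pip install pyyaml

with open("config.yml") as f:  # file name is an assumption
    config = yaml.safe_load(f) or {}

use_cfg = config.get("model", {}).get("use_cfg", False)
print(use_cfg)  # False unless the config enables it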


@@ -1,8 +1,8 @@
 """
 Functions for logging generation events.
 """
-from typing import Dict
 from pydantic import BaseModel
+from typing import Dict, Optional
 
 from logger import init_logger
 
@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")
 
 
-def log_prompt(prompt: str):
+def log_prompt(prompt: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
         logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
 
+        if negative_prompt:
+            formatted_negative_prompt = "\n" + negative_prompt
+            logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
+
 
 def log_response(response: str):
     """Logs the response to console."""


@@ -47,6 +47,7 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
     gpu_split: Optional[list] = None
+    use_cfg: bool = False
 
     active_loras: List[ExLlamaV2Lora] = []
@@ -95,6 +96,8 @@ class ModelContainer:
                 tensors, per device
             'no_flash_attn' (bool): Turns off flash attention
                 (increases vram usage) (default: False)
+            'use_cfg' (bool): Enables CFG support. Disables flash attention
+                (default: False)
         """
 
         self.quiet = quiet
@@ -135,8 +138,18 @@ class ModelContainer:
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )
 
-        # Turn off flash attention?
-        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
+        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
+            self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
+        else:
+            logger.warning(
+                "CFG is not supported by the currently installed ExLlamaV2 version."
+            )
+
+        # Turn off flash attention if CFG is on
+        # Workaround until batched FA2 is fixed in exllamav2 upstream
+        self.config.no_flash_attn = (
+            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        )
 
         # low_mem is currently broken in exllamav2. Don't use it until it's
         # fixed.
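
The hasattr probe doubles as a feature check: only exllamav2 builds that support CFG expose cfg_scale on the sampler settings, so older installs fall through to the warning. The check in isolation (requires exllamav2 to be installed):

from exllamav2.generator import ExLlamaV2Sampler

supports_cfg = hasattr(ExLlamaV2Sampler.Settings, "cfg_scale")
print(f"exllamav2 CFG support: {supports_cfg}")
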
@@ -348,10 +361,15 @@ class ModelContainer:
             if isinstance(value, str):
                 yield value
 
+        batch_size = 2 if self.use_cfg else 1
         if self.cache_fp8:
-            self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache_8bit(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
         else:
-            self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
 
         if self.gpu_split_auto:
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -561,6 +579,19 @@ class ModelContainer:
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
         gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
 
+        # Set CFG scale and negative prompt
+        cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
+        negative_prompt = None
+        if cfg_scale not in [None, 1.0]:
+            if self.use_cfg:
+                gen_settings.cfg_scale = cfg_scale
+                negative_prompt = kwargs.get("negative_prompt")
+            else:
+                logger.warn(
+                    "CFG is currently disabled. "
+                    + "Please reload your model with use_cfg = True.",
+                )
+
         gen_settings.token_presence_penalty = unwrap(
             kwargs.get("presence_penalty"), 0.0
         )
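
Note the two-level guard: a cfg_scale of None or exactly 1.0 is treated as CFG-off, and a non-neutral scale still does nothing unless the model was loaded with use_cfg. A standalone mirror of that logic:

from typing import Optional, Tuple

def resolve_cfg(
    cfg_scale: Optional[float], use_cfg: bool, negative_prompt: Optional[str]
) -> Tuple[float, Optional[str]]:
    scale = 1.0 if cfg_scale is None else cfg_scale
    if scale == 1.0:
        return 1.0, None  # neutral scale: guidance is a no-op
    if not use_cfg:
        print("CFG is currently disabled. Please reload your model with use_cfg = True.")
        return 1.0, None
    return scale, negative_prompt

print(resolve_cfg(1.5, True, "Do not mention storms."))   # (1.5, '...')
print(resolve_cfg(1.5, False, "Do not mention storms."))  # warns, (1.0, None)
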
@@ -635,7 +666,7 @@ class ModelContainer:
         )
 
         # Log prompt to console
-        log_prompt(prompt)
+        log_prompt(prompt, negative_prompt)
 
         # Set logit bias
         if logit_bias:
@@ -663,8 +694,18 @@ class ModelContainer:
         self.generator.set_stop_conditions(stop_conditions)
 
         # Tokenized context
-        ids = self.tokenizer.encode(
-            prompt, add_bos=add_bos_token, encode_special_tokens=True
+        ids, offsets = self.tokenizer.encode(
+            [prompt, negative_prompt]
+            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
+            else prompt,
+            add_bos=add_bos_token,
+            encode_special_tokens=True,
+            return_offsets=True,
+        )
+        mask = (
+            self.tokenizer.padding_mask(ids)
+            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
+            else None
         )
 
         context_len = len(ids[0])
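
When CFG is active, both prompts are encoded as one two-row batch, and return_offsets plus padding_mask (both part of the exllamav2 tokenizer API used above) handle the rows having different lengths. A usage sketch mirroring those calls (the model path is a placeholder):

from exllamav2 import ExLlamaV2Config, ExLlamaV2Tokenizer

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"  # placeholder path
config.prepare()
tokenizer = ExLlamaV2Tokenizer(config)

# Row 0: positive prompt, row 1: negative prompt, padded to equal length.
ids, offsets = tokenizer.encode(
    ["a much longer positive prompt goes here", "short negative"],
    add_bos=True,
    encode_special_tokens=True,
    return_offsets=True,
)
mask = tokenizer.padding_mask(ids)  # marks the padding in the shorter row
print(ids.shape, mask.shape, offsets)
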
@@ -683,7 +724,7 @@ class ModelContainer:
         start_time = time.time()
         last_chunk_time = start_time
 
-        save_tokens = torch.empty((1, 0), dtype=torch.bool)
+        save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
         chunk_buffer = ""
         chunk_tokens = 0
@@ -691,17 +732,31 @@ class ModelContainer:
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)
-                save_tokens = torch.empty((1, 0), dtype=torch.bool)
+                save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
                 overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
                 active_ids = ids[:, max(0, overflow) :]
                 chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
 
-                self.generator.begin_stream(
-                    active_ids,
-                    gen_settings,
-                    token_healing=token_healing,
-                    loras=self.active_loras,
-                )
+                # Split for exllama versions that have CFG
+                if self.use_cfg:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                        input_mask=mask,
+                        position_offsets=offsets,
+                    )
+                else:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                    )
+
+                # Reset offsets for subsequent passes if the context is truncated
+                offsets = None
 
             if auto_scale_penalty_range:
                 gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +769,9 @@ class ModelContainer:
                 ids[:, -1] = self.generator.sequence_ids[:, -2]
                 token_healing = False
 
-            save_tokens = torch.cat((save_tokens, tokens), dim=-1)
+            save_tokens = torch.cat(
+                (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
+            )
             chunk_buffer += chunk
             generated_tokens += 1
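
One last subtlety: with two cache rows, a token saved for healing has to be replicated across the batch before concatenation, which is what the expand call above does. In isolation:

import torch

batch_size = 2  # CFG keeps two sequences: prompt and negative prompt
save_tokens = torch.empty((batch_size, 0), dtype=torch.long)
tokens = torch.tensor([[42]])  # one sampled token, shape (1, 1)

# expand broadcasts the single token across both batch rows before cat.
save_tokens = torch.cat((save_tokens, tokens.expand(batch_size, -1)), dim=-1)
print(save_tokens)  # tensor([[42], [42]])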