Tree: Use unwrap and coalesce for optional handling

Python doesn't have first-class handling of optionals. The only ways
to handle them are checking whether the value is None via an if
statement or using the "or" keyword to unwrap with a default.

Previously, I used the "or" pattern to unwrap, but this caused issues
because any falsy value, not just None, fell back to the default. This
especially hurt booleans, where an explicit "False" was silently
changed to "True".
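
For example, with a hypothetical kwargs dict where the caller
explicitly passed False:

    kwargs = {"add_bos_token": False}

    # "or" treats the explicit False as missing and returns the default
    add_bos = kwargs.get("add_bos_token") or True   # True (wrong)

    # Only falling back on None preserves the caller's value
    value = kwargs.get("add_bos_token")
    add_bos = value if value is not None else True  # False (correct)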

Instead, add two new functions, unwrap and coalesce, that properly
implement None coalescing in a functional way: unwrap falls back to
the default only when the value is None, and coalesce returns the
first non-None argument.
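
A quick sketch of the intended behavior (the definitions are in the
utils.py hunk below):

    unwrap(False, True)        # False - explicit falsy value is kept
    unwrap(None, True)         # True  - only None triggers the default
    coalesce(None, False, -1)  # False - first non-None argument wins
    coalesce(None, None, -1)   # -1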

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2023-12-09 21:52:17 -05:00
parent 7380a3b79a
commit 5ae2a91c04
5 changed files with 83 additions and 68 deletions


@@ -1,5 +1,6 @@
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional, Union
+from utils import coalesce

 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
@@ -83,7 +84,7 @@ class CommonCompletionRequest(BaseModel):
             "min_p": self.min_p,
             "tfs": self.tfs,
             "repetition_penalty": self.repetition_penalty,
-            "repetition_range": self.repetition_range or self.repetition_penalty_range or -1,
+            "repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1),
             "repetition_decay": self.repetition_decay,
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,


@@ -12,6 +12,7 @@ from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 from packaging import version
 from typing import Optional, List, Dict
+from utils import unwrap

 # Check fastchat
 try:
@@ -30,7 +31,7 @@ def create_completion_response(text: str, prompt_tokens: int, completion_tokens:
     response = CompletionResponse(
         choices = [choice],
-        model = model_name or "",
+        model = unwrap(model_name, ""),
         usage = UsageStats(prompt_tokens = prompt_tokens,
                            completion_tokens = completion_tokens,
                            total_tokens = prompt_tokens + completion_tokens)
@@ -51,7 +52,7 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
     response = ChatCompletionResponse(
         choices = [choice],
-        model = model_name or "",
+        model = unwrap(model_name, ""),
         usage = UsageStats(prompt_tokens = prompt_tokens,
                            completion_tokens = completion_tokens,
                            total_tokens = prompt_tokens + completion_tokens)
@@ -80,7 +81,7 @@ def create_chat_completion_stream_chunk(const_id: str,
     chunk = ChatCompletionStreamChunk(
         id = const_id,
         choices = [choice],
-        model = model_name or ""
+        model = unwrap(model_name, "")
     )

     return chunk

main.py

@@ -28,7 +28,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import get_generator_error, get_sse_packet, load_progress
+from utils import get_generator_error, get_sse_packet, load_progress, unwrap
 from uuid import uuid4

 app = FastAPI()
@@ -54,17 +54,17 @@ app.add_middleware(
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
-    model_config = config.get("model") or {}
+    model_config = unwrap(config.get("model"), {})
     if "model_dir" in model_config:
         model_path = pathlib.Path(model_config["model_dir"])
     else:
         model_path = pathlib.Path("models")

-    draft_config = model_config.get("draft") or {}
+    draft_config = unwrap(model_config.get("draft"), {})
     draft_model_dir = draft_config.get("draft_model_dir")

     models = get_model_list(model_path.resolve(), draft_model_dir)

-    if model_config.get("use_dummy_models") or False:
+    if unwrap(model_config.get("use_dummy_models"), False):
         models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))

     return models
@@ -89,19 +89,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not data.name:
         raise HTTPException(400, "model_name not found.")

-    model_config = config.get("model") or {}
-    model_path = pathlib.Path(model_config.get("model_dir") or "models")
+    model_config = unwrap(config.get("model"), {})
+    model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
     model_path = model_path / data.name

     load_data = data.dict()

     # TODO: Add API exception if draft directory isn't found
-    draft_config = model_config.get("draft") or {}
+    draft_config = unwrap(model_config.get("draft"), {})
     if data.draft:
         if not data.draft.draft_model_name:
             raise HTTPException(400, "draft_model_name was not found inside the draft object.")

-        load_data["draft"]["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
+        load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")

     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
@@ -167,9 +167,9 @@ async def unload_model():
 @app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 @app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_all_loras():
-    model_config = config.get("model") or {}
-    lora_config = model_config.get("lora") or {}
-    lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
+    model_config = unwrap(config.get("model"), {})
+    lora_config = unwrap(model_config.get("lora"), {})
+    lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))

     loras = get_lora_list(lora_path.resolve())
@@ -196,9 +196,9 @@ async def load_model(data: LoraLoadRequest):
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")

-    model_config = config.get("model") or {}
-    lora_config = model_config.get("lora") or {}
-    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+    model_config = unwrap(config.get("model"), {})
+    lora_config = unwrap(model_config.get("lora"), {})
+    lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
     if not lora_dir.exists():
         raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
@@ -208,8 +208,8 @@ async def load_model(data: LoraLoadRequest):
     result = model_container.load_loras(lora_dir, **data.dict())
     return LoraLoadResponse(
-        success = result.get("success") or [],
-        failure = result.get("failure") or []
+        success = unwrap(result.get("success"), []),
+        failure = unwrap(result.get("failure"), [])
     )

 # Unload lora endpoint
@@ -234,7 +234,7 @@ async def encode_tokens(data: TokenEncodeRequest):
 @app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
-    response = TokenDecodeResponse(text = message or "")
+    response = TokenDecodeResponse(text = unwrap(message, ""))

     return response
@@ -337,7 +337,7 @@ if __name__ == "__main__":
     # Load from YAML config. Possibly add a config -> kwargs conversion function
     try:
         with open('config.yml', 'r', encoding = "utf8") as config_file:
-            config = yaml.safe_load(config_file) or {}
+            config = unwrap(yaml.safe_load(config_file), {})
     except Exception as e:
         print(
             "The YAML config couldn't load because of the following error:",
@@ -348,10 +348,10 @@ if __name__ == "__main__":
     # If an initial model name is specified, create a container and load the model
-    model_config = config.get("model") or {}
+    model_config = unwrap(config.get("model"), {})
     if "model_name" in model_config:
         # TODO: Move this to model_container
-        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
         model_path = model_path / model_config.get("model_name")

         model_container = ModelContainer(model_path.resolve(), False, **model_config)
@@ -366,12 +366,12 @@ if __name__ == "__main__":
             loading_bar.next()

         # Load loras
-        lora_config = model_config.get("lora") or {}
+        lora_config = unwrap(model_config.get("lora"), {})
         if "loras" in lora_config:
-            lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+            lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
             model_container.load_loras(lora_dir.resolve(), **lora_config)

-    network_config = config.get("network") or {}
+    network_config = unwrap(config.get("network"), {})
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),


@@ -13,6 +13,7 @@ from exllamav2.generator import(
     ExLlamaV2Sampler
 )
 from typing import List, Optional, Union
+from utils import coalesce, unwrap

 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -30,7 +31,7 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
-    gpu_split: list or None = None
+    gpu_split: Optional[list] = None

     active_loras: List[ExLlamaV2Lora] = []
@@ -68,7 +69,7 @@ class ModelContainer:
         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
         self.gpu_split = kwargs.get("gpu_split")
-        self.gpu_split_auto = kwargs.get("gpu_split_auto") or True
+        self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)

         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
@@ -78,14 +79,14 @@ class ModelContainer:
         base_seq_len = self.config.max_seq_len

         # Then override the max_seq_len if present
-        self.config.max_seq_len = kwargs.get("max_seq_len") or 4096
-        self.config.scale_pos_emb = kwargs.get("rope_scale") or 1.0
+        self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
+        self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)

         # Automatically calculate rope alpha
-        self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
+        self.config.scale_alpha_value = unwrap(kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len))

         # Turn off flash attention?
-        self.config.no_flash_attn = kwargs.get("no_flash_attn") or False
+        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attn"), False)

         # low_mem is currently broken in exllamav2. Don't use it until it's fixed.
         """
@@ -93,11 +94,11 @@ class ModelContainer:
             self.config.set_low_mem()
         """

-        chunk_size = min(kwargs.get("chunk_size") or 2048, self.config.max_seq_len)
+        chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2

-        draft_args = kwargs.get("draft") or {}
+        draft_args = unwrap(kwargs.get("draft"), {})
         draft_model_name = draft_args.get("draft_model_name")
         enable_draft = draft_args and draft_model_name
@@ -109,14 +110,14 @@ class ModelContainer:
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
-            draft_model_path = pathlib.Path(draft_args.get("draft_model_dir") or "models")
+            draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
             draft_model_path = draft_model_path / draft_model_name

             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()

-            self.draft_config.scale_pos_emb = draft_args.get("draft_rope_scale") or 1.0
-            self.draft_config.scale_alpha_value = draft_args.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
+            self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
+            self.draft_config.scale_alpha_value = unwrap(draft_args.get("draft_rope_alpha"), self.calculate_rope_alpha(self.draft_config.max_seq_len))
             self.draft_config.max_seq_len = self.config.max_seq_len

         if "chunk_size" in kwargs:
@@ -151,13 +152,13 @@ class ModelContainer:
         Load loras
         """

-        loras = kwargs.get("loras") or []
+        loras = unwrap(kwargs.get("loras"), [])
         success: List[str] = []
         failure: List[str] = []

         for lora in loras:
-            lora_name = lora.get("name") or None
-            lora_scaling = lora.get("scaling") or 1.0
+            lora_name = lora.get("name")
+            lora_scaling = unwrap(lora.get("scaling"), 1.0)

             if lora_name is None:
                 print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
@@ -265,13 +266,13 @@ class ModelContainer:
             # Assume token encoding
             return self.tokenizer.encode(
                 text,
-                add_bos = kwargs.get("add_bos_token") or True,
-                encode_special_tokens = kwargs.get("encode_special_tokens") or True
+                add_bos = unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
             )

         if ids:
             # Assume token decoding
             ids = torch.tensor([ids])
-            return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens") or True)[0]
+            return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]

     def generate(self, prompt: str, **kwargs):
@@ -311,10 +312,10 @@ class ModelContainer:
         """

-        token_healing = kwargs.get("token_healing") or False
-        max_tokens = kwargs.get("max_tokens") or 150
-        stream_interval = kwargs.get("stream_interval") or 0
-        generate_window = min(kwargs.get("generate_window") or 512, max_tokens)
+        token_healing = unwrap(kwargs.get("token_healing"), False)
+        max_tokens = unwrap(kwargs.get("max_tokens"), 150)
+        stream_interval = unwrap(kwargs.get("stream_interval"), 0)
+        generate_window = min(unwrap(kwargs.get("generate_window"), 512), max_tokens)

         # Sampler settings
@@ -322,42 +323,43 @@ class ModelContainer:
         # Warn of unsupported settings if the setting is enabled
-        if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
+        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")

-        if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
+        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")

-        if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
+        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")

-        if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
+        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
             print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")

         #Apply settings
-        gen_settings.temperature = kwargs.get("temperature") or 1.0
-        gen_settings.temperature_last = kwargs.get("temperature_last") or False
-        gen_settings.top_k = kwargs.get("top_k") or 0
-        gen_settings.top_p = kwargs.get("top_p") or 1.0
-        gen_settings.min_p = kwargs.get("min_p") or 0.0
-        gen_settings.tfs = kwargs.get("tfs") or 1.0
-        gen_settings.typical = kwargs.get("typical") or 1.0
-        gen_settings.mirostat = kwargs.get("mirostat") or False
+        gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
+        gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
+        gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
+        gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
+        gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
+        gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
+        gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
+        gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)

         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = kwargs.get("mirostat_tau") or 1.5
-        gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
-        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
-        gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
+        gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
+        gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
+        gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
+        gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)

         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed fallback
+        # Always default to 0 if something goes wrong
         fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
-        gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
+        gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)

-        stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
-        ban_eos_token = kwargs.get("ban_eos_token") or False
+        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)

         # Ban the EOS token if specified. If not, append to stop conditions as well.
@@ -383,7 +385,7 @@ class ModelContainer:
         ids = self.tokenizer.encode(
             prompt,
-            add_bos = kwargs.get("add_bos_token") or True,
+            add_bos = unwrap(kwargs.get("add_bos_token"), True),
             encode_special_tokens = True
         )
         context_len = len(ids[0])


@@ -30,3 +30,14 @@ def get_generator_error(message: str):
 def get_sse_packet(json_data: str):
     return f"data: {json_data}\n\n"
+
+# Unwrap function for Optionals
+def unwrap(wrapped, default = None):
+    if wrapped is None:
+        return default
+    else:
+        return wrapped
+
+# Coalesce function for multiple unwraps
+def coalesce(*args):
+    return next((arg for arg in args if arg is not None), None)