Model: Add speculative decoding support via config

Speculative decoding makes use of a smaller draft model that ingests the
prompt and proposes tokens, which the main model then verifies.

Add options to the config to support this. API options will be added
in a separate commit.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-18 01:38:54 -05:00
parent 78a6587b95
commit 27ebec3b35
3 changed files with 38 additions and 14 deletions

16
main.py
View File

@@ -56,7 +56,7 @@ async def list_models():
else:
model_path = pathlib.Path("models")
models = get_model_list(model_path)
models = get_model_list(model_path.resolve())
return models
@@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
def generator():
global model_container
model_config = config.get("model", {})
model_config = config.get("model") or {}
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
@@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
model_path = model_path / data.name
model_container = ModelContainer(model_path, False, **data.dict())
model_container = ModelContainer(model_path.resolve(), False, **data.dict())
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
@@ -217,12 +217,12 @@ if __name__ == "__main__":
# If an initial model name is specified, create a container and load the model
model_config = config.get("model", {})
model_config = config.get("model") or {}
if "model_name" in model_config:
model_path = pathlib.Path(model_config.get("model_dir", "models"))
model_path = model_path / model_config["model_name"]
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / model_config.get("model_name")
model_container = ModelContainer(model_path, False, **model_config)
model_container = ModelContainer(model_path.resolve(), False, **model_config)
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
@@ -233,7 +233,7 @@ if __name__ == "__main__":
else:
loading_bar.next()
network_config = config.get("network", {})
network_config = config.get("network") or {}
uvicorn.run(
app,
host=network_config.get("host", "127.0.0.1"),