Model: Add speculative decoding support via config
Speculative decoding makes use of draft models that ingest the prompt before forwarding it to the main model. Add options to the config to support this; API options will follow in a separate commit.

Signed-off-by: kingbri <bdashore3@proton.me>
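The draft-model config keys themselves are not visible in the main.py diff below (they presumably live in the sample config and the model container, which are not shown here). As a rough sketch of how such a section could be consumed, assuming a hypothetical `draft` block with hypothetical keys `draft_model_name` and `draft_model_dir`, mirroring the null-safe lookup pattern this commit introduces:

```python
# Sketch only: the "draft" section and its key names are assumptions,
# not taken from this commit's diff.
import pathlib

import yaml

with open("config.yml", "r", encoding="utf-8") as stream:
    config = yaml.safe_load(stream) or {}

# An empty "draft:" section parses to None, so normalize with "or {}".
draft_config = config.get("draft") or {}
if "draft_model_name" in draft_config:
    draft_dir = pathlib.Path(draft_config.get("draft_model_dir") or "models")
    draft_path = (draft_dir / draft_config["draft_model_name"]).resolve()
    # The draft model would be handed to the model container here, so it
    # can speculatively ingest the prompt ahead of the main model.
```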
main.py (16 changed lines)
@@ -56,7 +56,7 @@ async def list_models():
     else:
         model_path = pathlib.Path("models")
 
-    models = get_model_list(model_path)
+    models = get_model_list(model_path.resolve())
 
     return models
 
@@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
     def generator():
         global model_container
 
-        model_config = config.get("model", {})
+        model_config = config.get("model") or {}
         if "model_dir" in model_config:
            model_path = pathlib.Path(model_config["model_dir"])
         else:
@@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
 
         model_path = model_path / data.name
 
-        model_container = ModelContainer(model_path, False, **data.dict())
+        model_container = ModelContainer(model_path.resolve(), False, **data.dict())
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -217,12 +217,12 @@ if __name__ == "__main__":
 
     # If an initial model name is specified, create a container and load the model
 
-    model_config = config.get("model", {})
+    model_config = config.get("model") or {}
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model_dir", "models"))
-        model_path = model_path / model_config["model_name"]
+        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = model_path / model_config.get("model_name")
 
-        model_container = ModelContainer(model_path, False, **model_config)
+        model_container = ModelContainer(model_path.resolve(), False, **model_config)
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -233,7 +233,7 @@ if __name__ == "__main__":
     else:
         loading_bar.next()
 
-    network_config = config.get("network", {})
+    network_config = config.get("network") or {}
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),
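Two patterns recur throughout the diff. First, `config.get("model", {})` still returns `None` when the YAML section is present but empty, because the key then exists with a null value; `config.get("model") or {}` normalizes that case to an empty dict. Second, `Path.resolve()` converts the possibly relative model directory into an absolute path, presumably so the loader behaves the same regardless of the working directory. A minimal standalone illustration of both:

```python
# Minimal illustration of the two patterns in this diff; standalone
# example, not project code.
import pathlib

import yaml

# A section that exists but is empty ("model:") parses to None, so the
# two-argument dict.get default never kicks in:
config = yaml.safe_load("model:\n")
assert config.get("model", {}) is None       # default NOT used
assert (config.get("model") or {}) == {}     # "or {}" covers the None case

# resolve() absolutizes a relative path against the current directory:
model_path = pathlib.Path("models") / "some-model"
print(model_path.resolve())  # e.g. /srv/tabbyAPI/models/some-model
```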