Model: Fix gpu split params

gpu_split_auto is a bool and gpu_split is an array of floats giving the
number of GBs to allocate per GPU.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date:   2023-11-14 23:20:14 -05:00
Parent: ea91d17a11
Commit: 126afdfdc2

3 changed files with 6 additions and 5 deletions
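For callers, the request shape changes as follows. This is a hypothetical sketch, not part of the commit: the load endpoint's path isn't shown here, and the field names come from ModelLoadRequest in the first diff below.

# Before this commit: one field did double duty as the auto flag and the split.
old_request = {"name": "my-model", "gpu_split": "auto"}

# After this commit: the auto flag and the per-GPU allocation are separate.
new_request = {
    "name": "my-model",
    "gpu_split_auto": False,
    "gpu_split": [18.0, 24.0],  # GBs to allocate on GPU 0 and GPU 1
}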


@@ -15,7 +15,8 @@ class ModelList(BaseModel):
 class ModelLoadRequest(BaseModel):
     name: str
     max_seq_len: Optional[int] = 4096
-    gpu_split: Optional[str] = "auto"
+    gpu_split_auto: Optional[bool] = True
+    gpu_split: Optional[List[float]] = Field(default_factory=list)
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
     no_flash_attention: Optional[bool] = False
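The updated schema can be exercised on its own. A minimal, runnable sketch assuming only the fields shown in this hunk (the real class defines more, and already has the List and Field imports in scope):

from typing import List, Optional
from pydantic import BaseModel, Field

class ModelLoadRequest(BaseModel):
    name: str
    max_seq_len: Optional[int] = 4096
    gpu_split_auto: Optional[bool] = True
    # default_factory avoids sharing one mutable list across instances
    gpu_split: Optional[List[float]] = Field(default_factory=list)

req = ModelLoadRequest(name="my-model", gpu_split_auto=False, gpu_split=[18, 24])
print(req.gpu_split)  # [18.0, 24.0] -- Pydantic coerces the ints to floats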


@@ -139,8 +139,6 @@ if __name__ == "__main__":
             loading_bar.finish()
         else:
             loading_bar.next()
 
-    print("Model successfully loaded.")
-
     network_config = config["network"]
     uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")


@@ -55,7 +55,7 @@ class ModelContainer:
         By default, the draft model's alpha value is calculated automatically to scale to the size of the
         full model.
         'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
-        'gpu_split' (list): Allocation for weights and (some) tensors, per device
+        'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
         'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
         """
@@ -63,7 +63,7 @@ class ModelContainer:
         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
         self.gpu_split = kwargs.get("gpu_split", None)
-        self.gpu_split_auto = self.gpu_split == "auto"
+        self.gpu_split_auto = kwargs.get("gpu_split_auto", True)
 
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
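For illustration, here is roughly how these two attributes would drive loading. The load path isn't part of this diff, so treat this as an assumption based on exllamav2's public ExLlamaV2.load / load_autosplit API rather than the commit's actual code:

from exllamav2 import ExLlamaV2, ExLlamaV2Cache

def load_with_split(model: ExLlamaV2, gpu_split_auto: bool, gpu_split: list) -> ExLlamaV2Cache:
    if gpu_split_auto:
        # Autosplit: a lazy cache lets exllamav2 spread the weights across
        # all visible devices as it loads them.
        cache = ExLlamaV2Cache(model, lazy=True)
        model.load_autosplit(cache)
    else:
        # Manual split: GBs to allocate per device, e.g. [18.0, 24.0]
        model.load(gpu_split)
        cache = ExLlamaV2Cache(model)
    return cache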
@@ -177,6 +177,8 @@ class ModelContainer:
         self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer, self.draft_model, self.draft_cache)
 
+        print("Model successfully loaded.")
+
     def unload(self):
         """