Mirror of https://github.com/theroyallab/tabbyAPI.git
Synced 2026-03-15 00:07:28 +00:00
Model: Make loading use less VRAM
The model loader was using more VRAM on a single GPU than base exllamav2's loader. This was because single-GPU setups still ran through the autosplit config, which allocates an extra VRAM buffer for safe loading. Turn autosplit off for single-GPU setups (and off by default). This change should let users run models that require the entire card, hopefully with faster T/s as well. For example, Mixtral at 3.75bpw went from ~30 T/s to ~50 T/s due to the extra VRAM headroom on Windows.

Signed-off-by: kingbri <bdashore3@proton.me>
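For orientation, a condensed sketch of the loading policy this change introduces. Function and variable names here are hypothetical; the actual logic is in the ExllamaV2Container diff below.

from typing import List, Optional

def pick_load_path(gpu_split_auto: bool, gpu_split: Optional[List[float]]) -> str:
    if gpu_split_auto:
        # Autosplit reserves an extra VRAM buffer while weights stream in,
        # so after this commit it has to be requested explicitly.
        return "autosplit"
    # With autosplit off (the new default), a single GPU takes the whole
    # model, and multi-GPU users pass an explicit per-device split in GB.
    return "manual split" if gpu_split else "single GPU"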
@@ -70,7 +70,7 @@ class ModelLoadRequest(BaseModel):
         default=None,
         examples=[4096],
     )
-    gpu_split_auto: Optional[bool] = True
+    gpu_split_auto: Optional[bool] = False
     gpu_split: Optional[List[float]] = Field(
         default_factory=list, examples=[[24.0, 20.0]]
     )
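A hypothetical client-side view of the schema change above, assuming tabbyAPI's /v1/model/load endpoint and x-admin-key header; the URL, key, and model name are placeholders. With the new default, a single-GPU request can simply omit gpu_split_auto and skip the autosplit reserve.

import requests

payload = {
    "name": "Mixtral-8x7B-exl2-3.75bpw",   # placeholder model folder
    # "gpu_split_auto": True,               # opt back in to autosplit
    "gpu_split": [24.0, 20.0],              # or pin GB per device manually
}

resp = requests.post(
    "http://localhost:5000/v1/model/load",  # placeholder URL
    json=payload,
    headers={"x-admin-key": "YOUR_ADMIN_KEY"},  # placeholder credential
)
resp.raise_for_status()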
@@ -104,7 +104,7 @@ class ExllamaV2Container:

         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
         self.gpu_split = kwargs.get("gpu_split")
-        self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
+        self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), False)

         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
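The flipped default takes effect through unwrap, tabbyAPI's null-coalescing helper. A minimal sketch of the behavior relied on here, assuming the conventional implementation:

from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap(wrapped: Optional[T], default: Optional[T] = None) -> Optional[T]:
    # kwargs.get("gpu_split_auto") yields None when the caller never set the
    # flag, so the second argument acts as the effective default, which is
    # the value this commit flips from True to False.
    return wrapped if wrapped is not None else default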
@@ -347,16 +347,22 @@ class ExllamaV2Container:
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

         # Load model
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)

-        for value in self.model.load_gen(
-            self.gpu_split, callback_gen=progress_callback
-        ):
-            if isinstance(value, str):
-                yield value
+        # Load model with manual split
+        # Entrypoint for single GPU users
+        if not self.gpu_split_auto:
+            logger.info(
+                "Loading with a manual GPU split (or a one GPU setup)"
+            )
+            for value in self.model.load_gen(
+                self.gpu_split,
+                callback_gen=progress_callback,
+            ):
+                if value:
+                    yield value

         batch_size = 2 if self.use_cfg else 1
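For comparison, the manual path reduces to exllamav2's plain loader. A standalone sketch under that assumption; the model path is a placeholder:

from exllamav2 import ExLlamaV2, ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/models/Mixtral-8x7B-exl2-3.75bpw"  # placeholder path
config.prepare()

model = ExLlamaV2(config)

# gpu_split=None lets one GPU take the whole model with no autosplit
# reserve buffer; a list such as [24.0, 20.0] pins GB per device instead.
model.load(gpu_split=None)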
@@ -369,14 +375,19 @@ class ExllamaV2Container:
             self.model, lazy=self.gpu_split_auto, batch_size=batch_size
         )

-        reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
-        yield from self.model.load_autosplit_gen(
-            self.cache,
-            reserve_vram=reserve,
-            last_id_only=True,
-            callback_gen=progress_callback,
-        )
+        # Load model with autosplit
+        if self.gpu_split_auto:
+            logger.info("Loading with autosplit")
+
+            reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
+            for value in self.model.load_autosplit_gen(
+                self.cache,
+                reserve_vram=reserve,
+                last_id_only=True,
+                callback_gen=progress_callback,
+            ):
+                if value:
+                    yield value

         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
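The autosplit path corresponds to exllamav2's lazy-cache loader, whose reserve list is the VRAM cost the commit message describes. A standalone sketch; the reserve value here is illustrative, standing in for tabbyAPI's AUTO_SPLIT_RESERVE_BYTES constant:

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config

AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2  # assumed headroom on device 0

config = ExLlamaV2Config()
config.model_dir = "/models/Mixtral-8x7B-exl2-3.75bpw"  # placeholder path
config.prepare()

model = ExLlamaV2(config)

# Autosplit requires a lazy cache: it is materialized device by device as
# the weights stream in, and the reserve list keeps headroom on device 0.
# That reserved buffer is exactly the VRAM cost this commit makes opt-in.
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache, reserve_vram=[AUTO_SPLIT_RESERVE_BYTES] + [0] * 16)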
@@ -395,6 +406,11 @@ class ExllamaV2Container:
         self.generator.return_probabilities = True
         self.generator.return_logits = True

+        # Clean up any extra vram usage from torch and cuda
+        # (Helps reduce VRAM bottlenecking on Windows)
+        gc.collect()
+        torch.cuda.empty_cache()
+
         logger.info("Model successfully loaded.")

     def unload(self, loras_only: bool = False):
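To verify what the cleanup reclaims, reserved VRAM can be snapshotted around it; a hypothetical check using torch's memory stats:

import gc
import torch

before = torch.cuda.memory_reserved(0)
gc.collect()               # drop dangling Python references first
torch.cuda.empty_cache()   # then return cached blocks to the CUDA driver
after = torch.cuda.memory_reserved(0)
print(f"Reserved VRAM: {before / 1024**2:.0f} MiB -> {after / 1024**2:.0f} MiB")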
@@ -68,8 +68,9 @@ model:
   # Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral/Mixtral models)
   #override_base_seq_len:

-  # Automatically allocate resources to GPUs (default: True)
-  #gpu_split_auto: True
+  # Automatically allocate resources to GPUs (default: False)
+  # WARNING: Will use more VRAM for single GPU users
+  #gpu_split_auto: False

   # An integer array of GBs of vram to split between GPUs (default: [])
   #gpu_split: [20.6, 24]