diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index f1749fe..8653132 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -109,14 +109,21 @@ class ExllamaV2Container:
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
         if gpu_count > 1:
-            self.gpu_split = kwargs.get("gpu_split")
+            gpu_split = kwargs.get("gpu_split")
 
-            # Auto GPU split parameters
-            self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
-            autosplit_reserve_megabytes = unwrap(kwargs.get("autosplit_reserve"), [96])
-            self.autosplit_reserve = list(
-                map(lambda value: value * 1024**2, autosplit_reserve_megabytes)
-            )
+            if gpu_split:
+                # A manual split was provided: use it and explicitly disable
+                # autosplit so the two modes can never both appear active.
+                self.gpu_split = gpu_split
+                self.gpu_split_auto = False
+            else:
+                # Auto GPU split parameters
+                self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
+                autosplit_reserve_megabytes = unwrap(
+                    kwargs.get("autosplit_reserve"), [96]
+                )
+                self.autosplit_reserve = list(
+                    map(lambda value: value * 1024**2, autosplit_reserve_megabytes)
+                )
         else:
             self.gpu_split_auto = False
             logger.info("Disabling GPU split because one GPU is in use.")