diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index 217f5bf..abe907e 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -177,7 +177,11 @@ class ExllamaV3Container(BaseModelContainer):
             self.use_tp = True
 
             tp_backend = unwrap(kwargs.get("tensor_parallel_backend"), "native")
-            if not exllama_supports_nccl():
+            if tp_backend == "nccl" and not exllama_supports_nccl():
+                unsupported_message = (
+                    "NCCL is not available. Falling back to native backend."
+                )
+                logger.warning(unsupported_message)
                 tp_backend = "native"
 
             self.tp_backend = tp_backend
diff --git a/backends/exllamav3/utils.py b/backends/exllamav3/utils.py
index dbaffb5..0a90487 100644
--- a/backends/exllamav3/utils.py
+++ b/backends/exllamav3/utils.py
@@ -1,15 +1,13 @@
 import platform
 from loguru import logger
 
-
 def exllama_supports_nccl():
-    if platform.system() != "Windows":
+    if platform.system() == "Windows":
+        unsupported_message = (
+            "The NCCL tensor parallel backend is not supported on Windows."
+        )
+        logger.warning(unsupported_message)
         return False
 
-    unsupported_message = (
-        "The NCCL tensor parallel backend is not supported on Windows. \n"
-        "Switching to native backend."
-    )
-    logger.warning(unsupported_message)
-
-    return True
+    import torch
+    return torch.cuda.is_available() and torch.distributed.is_nccl_available()