diff --git a/backend/memory_management.py b/backend/memory_management.py index 050d7ea8..3fa5e5f1 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -1078,15 +1078,18 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma def can_install_bnb(): - if not torch.cuda.is_available(): + try: + if not torch.cuda.is_available(): + return False + + cuda_version = tuple(int(x) for x in torch.version.cuda.split('.')) + + if cuda_version >= (11, 7): + return True + + return False + except: return False - - cuda_version = tuple(int(x) for x in torch.version.cuda.split('.')) - - if cuda_version >= (11, 7): - return True - - return False def soft_empty_cache(force=False):