diff --git a/backend/memory_management.py b/backend/memory_management.py index 69f84de7..e4cb55ce 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -646,13 +646,13 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor if args.unet_in_fp8_e5m2: return torch.float8_e5m2 - if should_use_fp16(device=device, model_params=model_params, manual_cast=True): - if torch.float16 in supported_dtypes: - return torch.float16 - - if should_use_bf16(device, model_params=model_params, manual_cast=True): - if torch.bfloat16 in supported_dtypes: - return torch.bfloat16 + for candidate in supported_dtypes: + if candidate == torch.float16: + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): + return candidate + if candidate == torch.bfloat16: + if should_use_bf16(device, model_params=model_params, manual_cast=True): + return candidate return torch.float32