From 7c0f78e4248d4148a66c1eb5800d166d95378b84 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 16 Aug 2024 03:59:59 -0700
Subject: [PATCH] reduce cast

---
 backend/memory_management.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/backend/memory_management.py b/backend/memory_management.py
index 69f84de7..e4cb55ce 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -646,13 +646,13 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
     if args.unet_in_fp8_e5m2:
         return torch.float8_e5m2
 
-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for candidate in supported_dtypes:
+        if candidate == torch.float16:
+            if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+                return candidate
+        if candidate == torch.bfloat16:
+            if should_use_bf16(device, model_params=model_params, manual_cast=True):
+                return candidate
 
     return torch.float32
 