reduce cast

This commit is contained in:
layerdiffusion
2024-08-16 03:59:59 -07:00
parent dc7f92eb96
commit 7c0f78e424

View File

@@ -646,13 +646,13 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if args.unet_in_fp8_e5m2:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
for candidate in supported_dtypes:
if candidate == torch.float16:
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
return candidate
if candidate == torch.bfloat16:
if should_use_bf16(device, model_params=model_params, manual_cast=True):
return candidate
return torch.float32