mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-04 20:49:49 +00:00
reduce cast
This commit is contained in:
@@ -646,13 +646,13 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
if args.unet_in_fp8_e5m2:
|
||||
return torch.float8_e5m2
|
||||
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
for candidate in supported_dtypes:
|
||||
if candidate == torch.float16:
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
return candidate
|
||||
if candidate == torch.bfloat16:
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
return candidate
|
||||
|
||||
return torch.float32
|
||||
|
||||
|
||||
Reference in New Issue
Block a user