diff --git a/backend/memory_management.py b/backend/memory_management.py
index ea64d1b7..29e5239c 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -619,6 +619,18 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     return torch.float32
 
 
+def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    for candidate in supported_dtypes:
+        if candidate == torch.float16:
+            if should_use_fp16(inference_device, prioritize_performance=False):
+                return candidate
+        if candidate == torch.bfloat16:
+            if should_use_bf16(inference_device):
+                return candidate
+
+    return torch.float32
+
+
 def text_encoder_offload_device():
     if args.always_gpu:
         return get_torch_device()
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
index e824eddd..49a84a95 100644
--- a/backend/patcher/unet.py
+++ b/backend/patcher/unet.py
@@ -13,10 +13,9 @@ class UnetPatcher(ModelPatcher):
         unet_dtype = memory_management.unet_dtype(model_params=parameters)
         load_device = memory_management.get_torch_device()
         initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
-        manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device, supported_dtypes=config.supported_inference_dtypes)
-        manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
+        computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=config.supported_inference_dtypes)
         model.to(device=initial_load_device, dtype=unet_dtype)
-        model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
+        model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=computation_dtype)
         return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
 
     def __init__(self, *args, **kwargs):
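
For context, the new `get_computation_dtype` helper keys the decision purely on the inference device (via `should_use_fp16` / `should_use_bf16`), instead of deriving it from the storage dtype as `unet_manual_cast` did. A minimal usage sketch, assuming a checkout of this repo where the `backend` package is importable (not part of the patch itself):

```python
import torch
from backend import memory_management

load_device = memory_management.get_torch_device()

# Walks supported_dtypes in order and returns the first dtype the device can
# compute in: fp16 if should_use_fp16() allows it, bf16 if should_use_bf16()
# allows it, otherwise torch.float32. The result depends on the local GPU.
computation_dtype = memory_management.get_computation_dtype(
    load_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]
)
print(computation_dtype)  # e.g. torch.float16 on a typical consumer GPU
```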