revise kernel

This commit is contained in:
layerdiffusion
2024-08-07 17:24:22 -07:00
parent b61bf553ea
commit e1df7a1bae
2 changed files with 14 additions and 3 deletions

View File

@@ -619,6 +619,18 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return torch.float32
def get_computation_dtype(inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    """Choose the dtype to run computation in on ``inference_device``.

    Candidates are examined in the order given by ``supported_dtypes`` and
    the first acceptable one is returned:

    * ``torch.float16`` is accepted when
      ``should_use_fp16(inference_device, prioritize_performance=False)``
      is truthy.
    * ``torch.bfloat16`` is accepted when
      ``should_use_bf16(inference_device)`` is truthy.

    Any other candidate (including ``torch.float32``) is not matched inside
    the loop; ``torch.float32`` is only ever returned as the final fallback,
    regardless of where it appears in ``supported_dtypes``.

    Args:
        inference_device: Device handle forwarded unchanged to
            ``should_use_fp16`` / ``should_use_bf16``.
        supported_dtypes: Iterable of candidate dtypes in priority order.
            The default is an immutable tuple so it cannot be altered
            between calls.

    Returns:
        The first accepted candidate, or ``torch.float32`` when none is
        accepted (also the result for an empty ``supported_dtypes``).
    """
    # NOTE(review): should_use_fp16 / should_use_bf16 are defined elsewhere in
    # this module; presumably they report whether the device can use the
    # respective half-precision format — confirm against their definitions.
    for candidate in supported_dtypes:
        if candidate == torch.float16 and should_use_fp16(inference_device, prioritize_performance=False):
            return candidate
        if candidate == torch.bfloat16 and should_use_bf16(inference_device):
            return candidate
    return torch.float32
def text_encoder_offload_device():
if args.always_gpu:
return get_torch_device()

View File

@@ -13,10 +13,9 @@ class UnetPatcher(ModelPatcher):
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device, supported_dtypes=config.supported_inference_dtypes)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=config.supported_inference_dtypes)
model.to(device=initial_load_device, dtype=unet_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=computation_dtype)
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
def __init__(self, *args, **kwargs):