change some dtype behaviors based on community feedbacks

only influence old devices like 1080/70/60/50.
please remove cmd flags if you are on 1080/70/60/50 and previously used many cmd flags to tune performance
This commit is contained in:
layerdiffusion
2024-08-21 08:46:52 -07:00
parent 2b1e7366a7
commit 31bed671ac
4 changed files with 30 additions and 87 deletions

View File

@@ -110,12 +110,12 @@ class ControlNetPatcher(ControlModelPatcher):
controlnet_config['dtype'] = unet_dtype
load_device = memory_management.get_torch_device()
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
computation_dtype = memory_management.get_computation_dtype(load_device)
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
with using_forge_operations(dtype=unet_dtype):
with using_forge_operations(dtype=unet_dtype, manual_cast_enabled=computation_dtype != unet_dtype):
control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype)
if pth:
@@ -139,7 +139,7 @@ class ControlNetPatcher(ControlModelPatcher):
# TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=computation_dtype)
return ControlNetPatcher(control)
def __init__(self, model_patcher):