From 1419ef29aab7ea8c2132509c860f986ecda1d583 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 21 Aug 2024 10:10:49 -0700 Subject: [PATCH] Revert "change some dtype behaviors based on community feedbacks" This reverts commit 31bed671ac2c7daf0ecddd0724b78ee5d2abbe16. --- backend/loader.py | 10 +-- backend/memory_management.py | 99 +++++++++++++++++++++------ backend/patcher/controlnet.py | 2 +- modules_forge/supported_controlnet.py | 6 +- 4 files changed, 87 insertions(+), 30 deletions(-) diff --git a/backend/loader.py b/backend/loader.py index 17818085..570b591a 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -107,10 +107,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p model_loader = lambda c: IntegratedFluxTransformer2DModel(**c) unet_config = guess.unet_config.copy() - state_dict_parameters = memory_management.state_dict_parameters(state_dict) + state_dict_size = memory_management.state_dict_size(state_dict) state_dict_dtype = memory_management.state_dict_dtype(state_dict) - storage_dtype = memory_management.unet_dtype(model_params=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes) + storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes) unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype') @@ -140,15 +140,15 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p print(f'Using GGUF state dict: {type_counts}') load_device = memory_management.get_torch_device() - computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes) + computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes) offload_device = memory_management.unet_offload_device() if storage_dtype in ['nf4', 'fp4', 'gguf']: - initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=computation_dtype) + initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype) with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype): model = model_loader(unet_config) else: - initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=storage_dtype) + initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype) need_manual_cast = storage_dtype != computation_dtype to_args = dict(device=initial_device, dtype=storage_dtype) diff --git a/backend/memory_management.py b/backend/memory_management.py index 7b40bd03..5e1c1401 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -301,13 +301,6 @@ def state_dict_size(sd, exclude_device=None): return module_mem -def state_dict_parameters(sd): - module_mem = 0 - for k, v in sd.items(): - module_mem += v.nelement() - return module_mem - - def state_dict_dtype(state_dict): for k, v in state_dict.items(): if hasattr(v, 'is_gguf'): @@ -660,22 +653,44 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor for candidate in supported_dtypes: if candidate == torch.float16: - if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True): + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): return candidate if candidate == torch.bfloat16: - if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True): + if should_use_bf16(device, model_params=model_params, manual_cast=True): return candidate return torch.float32 -def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): +# None means no manual cast +def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): + if weight_dtype == torch.float32: + return None + + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) + if fp16_supported and weight_dtype == torch.float16: + return None + + bf16_supported = should_use_bf16(inference_device) + if bf16_supported and weight_dtype == torch.bfloat16: + return None + + if fp16_supported and torch.float16 in supported_dtypes: + return torch.float16 + + elif bf16_supported and torch.bfloat16 in supported_dtypes: + return torch.bfloat16 + else: + return torch.float32 + + +def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): for candidate in supported_dtypes: if candidate == torch.float16: - if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False): + if should_use_fp16(inference_device, prioritize_performance=False): return candidate if candidate == torch.bfloat16: - if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False): + if should_use_bf16(inference_device): return candidate return torch.float32 @@ -1005,17 +1020,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if props.major < 6: return False + fp16_works = False + # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled + # when the model doesn't actually fit on the card + # TODO: actually test if GP106 and others have the same type of behavior nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"] for x in nvidia_10_series: if x in props.name.lower(): - if manual_cast: - # For storage dtype - free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory()) - if (not prioritize_performance) or model_params * 4 > free_model_memory: - return True - else: - # For computation dtype - return False # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow. + fp16_works = True + + if fp16_works or manual_cast: + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + if (not prioritize_performance) or model_params * 4 > free_model_memory: + return True if props.major < 7: return False @@ -1063,7 +1080,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma bf16_works = torch.cuda.is_bf16_supported() if bf16_works or manual_cast: - free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory()) + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True @@ -1099,3 +1116,43 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) + + +def resolve_lowvram_weight(weight, model, key): # TODO: remove + return weight + + +# TODO: might be cleaner to put this somewhere else +import threading + + +class InterruptProcessingException(Exception): + pass + + +interrupt_processing_mutex = threading.RLock() + +interrupt_processing = False + + +def interrupt_current_processing(value=True): + global interrupt_processing + global interrupt_processing_mutex + with interrupt_processing_mutex: + interrupt_processing = value + + +def processing_interrupted(): + global interrupt_processing + global interrupt_processing_mutex + with interrupt_processing_mutex: + return interrupt_processing + + +def throw_exception_if_processing_interrupted(): + global interrupt_processing + global interrupt_processing_mutex + with interrupt_processing_mutex: + if interrupt_processing: + interrupt_processing = False + raise InterruptProcessingException() diff --git a/backend/patcher/controlnet.py b/backend/patcher/controlnet.py index 095629df..b0372d86 100644 --- a/backend/patcher/controlnet.py +++ b/backend/patcher/controlnet.py @@ -438,7 +438,7 @@ class ControlLora(ControlNet): self.manual_cast_dtype = model.computation_dtype - with using_forge_operations(operations=ControlLoraOps, dtype=dtype, manual_cast_enabled=self.manual_cast_dtype != dtype): + with using_forge_operations(operations=ControlLoraOps, dtype=dtype): self.control_model = cldm.ControlNet(**controlnet_config) self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype) diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py index 8ca7b768..1d199934 100644 --- a/modules_forge/supported_controlnet.py +++ b/modules_forge/supported_controlnet.py @@ -110,12 +110,12 @@ class ControlNetPatcher(ControlModelPatcher): controlnet_config['dtype'] = unet_dtype load_device = memory_management.get_torch_device() - computation_dtype = memory_management.get_computation_dtype(load_device) + manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device) controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] - with using_forge_operations(dtype=unet_dtype, manual_cast_enabled=computation_dtype != unet_dtype): + with using_forge_operations(dtype=unet_dtype): control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype) if pth: @@ -139,7 +139,7 @@ class ControlNetPatcher(ControlModelPatcher): # TODO: smarter way of enabling global_average_pooling global_average_pooling = True - control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=computation_dtype) + control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return ControlNetPatcher(control) def __init__(self, model_patcher):