From 4e3c78178a7eae73311fb269cdeaf7a0b40644d1 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 21 Aug 2024 10:23:38 -0700
Subject: [PATCH] [revised] change some dtype behaviors based on community
 feedback

Only influences old devices like the 1080/70/60/50. If you are on a
1080/70/60/50 and previously used many cmd flags to tune performance, please
remove those flags.
---
 backend/loader.py                     |  10 +--
 backend/memory_management.py          | 111 +++++++-------------------
 backend/patcher/controlnet.py         |   2 +-
 modules_forge/supported_controlnet.py |   6 +-
 4 files changed, 37 insertions(+), 92 deletions(-)

diff --git a/backend/loader.py b/backend/loader.py
index 570b591a..17818085 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -107,10 +107,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)

         unet_config = guess.unet_config.copy()
-        state_dict_size = memory_management.state_dict_size(state_dict)
+        state_dict_parameters = memory_management.state_dict_parameters(state_dict)
         state_dict_dtype = memory_management.state_dict_dtype(state_dict)

-        storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
+        storage_dtype = memory_management.unet_dtype(model_params=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)

         unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')

@@ -140,15 +140,15 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             print(f'Using GGUF state dict: {type_counts}')

         load_device = memory_management.get_torch_device()
-        computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
+        computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
         offload_device = memory_management.unet_offload_device()

         if storage_dtype in ['nf4', 'fp4', 'gguf']:
-            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
+            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=computation_dtype)
             with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
                 model = model_loader(unet_config)
         else:
-            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
+            initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=storage_dtype)
             need_manual_cast = storage_dtype != computation_dtype
             to_args = dict(device=initial_device, dtype=storage_dtype)

diff --git a/backend/memory_management.py b/backend/memory_management.py
index 5e1c1401..1d6e9157 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -301,6 +301,13 @@ def state_dict_size(sd, exclude_device=None):
     return module_mem


+def state_dict_parameters(sd):
+    module_mem = 0
+    for k, v in sd.items():
+        module_mem += v.nelement()
+    return module_mem
+
+
 def state_dict_dtype(state_dict):
     for k, v in state_dict.items():
         if hasattr(v, 'is_gguf'):
@@ -653,44 +660,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
                 return candidate

     return torch.float32


-# None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
-        return None
-
-    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
-    if fp16_supported and weight_dtype == torch.float16:
-        return None
-
-    bf16_supported = should_use_bf16(inference_device)
-    if bf16_supported and weight_dtype == torch.bfloat16:
-        return None
-
-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
-
-
-def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(inference_device, prioritize_performance=False):
+            if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(inference_device):
+            if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
                 return candidate

     return torch.float32

@@ -1020,19 +1005,17 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major < 6:
         return False

-    fp16_works = False
-    # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    # when the model doesn't actually fit on the card
-    # TODO: actually test if GP106 and others have the same type of behavior
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            fp16_works = True
-
-    if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
-        if (not prioritize_performance) or model_params * 4 > free_model_memory:
-            return True
+            if manual_cast:
+                # For storage dtype
+                free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+                if (not prioritize_performance) or model_params * 4 > free_model_memory:
+                    return True
+            else:
+                # For computation dtype
+                return False  # Flux on a 1080 can store the model in fp16 to reduce swap, but computation must be fp32, otherwise it is super slow.

     if props.major < 7:
         return False
@@ -1077,12 +1060,14 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major >= 8:
         return True

-    bf16_works = torch.cuda.is_bf16_supported()
-
-    if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
-        if (not prioritize_performance) or model_params * 4 > free_model_memory:
-            return True
+    if torch.cuda.is_bf16_supported():
+        # This is an older device, but bf16 still reports as supported,
+        # so in this case bf16 should only be used as the storage dtype.
+        if manual_cast:
+            # For storage dtype
+            free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+            if (not prioritize_performance) or model_params * 4 > free_model_memory:
+                return True

     return False

@@ -1116,43 +1101,3 @@ def soft_empty_cache(force=False):

 def unload_all_models():
     free_memory(1e30, get_torch_device())
-
-
-def resolve_lowvram_weight(weight, model, key): # TODO: remove
-    return weight
-
-
-# TODO: might be cleaner to put this somewhere else
-import threading
-
-
-class InterruptProcessingException(Exception):
-    pass
-
-
-interrupt_processing_mutex = threading.RLock()
-
-interrupt_processing = False
-
-
-def interrupt_current_processing(value=True):
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        interrupt_processing = value
-
-
-def processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        return interrupt_processing
-
-
-def throw_exception_if_processing_interrupted():
-    global interrupt_processing
-    global interrupt_processing_mutex
-    with interrupt_processing_mutex:
-        if interrupt_processing:
-            interrupt_processing = False
-            raise InterruptProcessingException()
diff --git a/backend/patcher/controlnet.py b/backend/patcher/controlnet.py
index b0372d86..095629df 100644
--- a/backend/patcher/controlnet.py
+++ b/backend/patcher/controlnet.py
@@ -438,7 +438,7 @@ class ControlLora(ControlNet):

         self.manual_cast_dtype = model.computation_dtype

-        with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
+        with using_forge_operations(operations=ControlLoraOps, dtype=dtype, manual_cast_enabled=self.manual_cast_dtype != dtype):
             self.control_model = cldm.ControlNet(**controlnet_config)

         self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py
index 1d199934..8ca7b768 100644
--- a/modules_forge/supported_controlnet.py
+++ b/modules_forge/supported_controlnet.py
@@ -110,12 +110,12 @@ class ControlNetPatcher(ControlModelPatcher):
             controlnet_config['dtype'] = unet_dtype

         load_device = memory_management.get_torch_device()
-        manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
+        computation_dtype = memory_management.get_computation_dtype(load_device)

         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]

-        with using_forge_operations(dtype=unet_dtype):
+        with using_forge_operations(dtype=unet_dtype, manual_cast_enabled=computation_dtype != unet_dtype):
             control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype)

         if pth:
@@ -139,7 +139,7 @@ class ControlNetPatcher(ControlModelPatcher):
         # TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True

-        control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+        control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=computation_dtype)
        return ControlNetPatcher(control)

     def __init__(self, model_patcher):
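
Reviewer sketch (not part of the patch): the heart of this change is that should_use_fp16/should_use_bf16 now treat manual_cast=True as 'asking about the storage dtype' and manual_cast=False as 'asking about the computation dtype', so a 10-series card can keep weights in fp16/bf16 while computing in fp32. Below is a minimal usage sketch of the patched helpers, assuming the patched backend.memory_management module is importable and using a tiny dummy state dict in place of a real checkpoint:

import torch
from backend import memory_management

# Tiny dummy state dict standing in for a real UNet / Flux checkpoint.
state_dict = {'weight': torch.zeros(1024, 1024, dtype=torch.float16)}

params = memory_management.state_dict_parameters(state_dict)  # element count, not bytes
load_device = memory_management.get_torch_device()

# Storage dtype: on a 1080-class card fp16/bf16 can still be chosen when the
# model would not otherwise fit in free memory (the manual_cast=True branch).
storage_dtype = memory_management.unet_dtype(model_params=params)

# Computation dtype: on the same card this now resolves to fp32, because
# fp16/bf16 math is much slower there (the manual_cast=False branch).
computation_dtype = memory_management.get_computation_dtype(load_device, parameters=params)

# The loader enables manual casting whenever the two disagree, e.g. fp16
# weights kept on the card but cast to fp32 for each computation.
need_manual_cast = storage_dtype != computation_dtype
print(storage_dtype, computation_dtype, need_manual_cast)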
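
A second illustrative sketch, assuming nothing beyond torch, showing why the loader now passes a parameter count (state_dict_parameters) instead of a byte size (state_dict_size) into the dtype decisions: should_use_fp16/should_use_bf16 estimate the fp32 footprint as model_params * 4, so the number fed in must be dtype-independent. The two standalone helpers below only mirror the backend functions for demonstration:

import torch

def state_dict_size(sd):
    # Bytes actually occupied by the tensors (depends on the stored dtype).
    return sum(v.nelement() * v.element_size() for v in sd.values())

def state_dict_parameters(sd):
    # Raw element count, independent of the stored dtype (mirrors the patched helper).
    return sum(v.nelement() for v in sd.values())

sd = {'weight': torch.zeros(1000, 1000, dtype=torch.float16)}
print(state_dict_size(sd))        # 2000000 bytes (fp16 = 2 bytes per element)
print(state_dict_parameters(sd))  # 1000000 elements
# should_use_fp16/bf16 then compare parameters * 4 (the fp32 footprint in bytes)
# against free memory, regardless of the dtype the checkpoint was saved in.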