mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 02:31:16 +00:00
change some dtype behaviors based on community feedback
This only affects old devices like the 1080/1070/1060/1050. If you are on a 1080/1070/1060/1050 and previously used many cmd flags to tune performance, please remove those cmd flags.
This commit is contained in:
@@ -107,10 +107,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||||
|
|
||||||
unet_config = guess.unet_config.copy()
|
unet_config = guess.unet_config.copy()
|
||||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
||||||
state_dict_dtype = memory_management.state_dict_dtype(state_dict)
|
state_dict_dtype = memory_management.state_dict_dtype(state_dict)
|
||||||
|
|
||||||
storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
|
storage_dtype = memory_management.unet_dtype(model_params=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
|
||||||
|
|
||||||
unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')
|
unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')
|
||||||
|
|
||||||
@@ -140,15 +140,15 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
|||||||
print(f'Using GGUF state dict: {type_counts}')
|
print(f'Using GGUF state dict: {type_counts}')
|
||||||
|
|
||||||
load_device = memory_management.get_torch_device()
|
load_device = memory_management.get_torch_device()
|
||||||
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
|
computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
|
||||||
offload_device = memory_management.unet_offload_device()
|
offload_device = memory_management.unet_offload_device()
|
||||||
|
|
||||||
if storage_dtype in ['nf4', 'fp4', 'gguf']:
|
if storage_dtype in ['nf4', 'fp4', 'gguf']:
|
||||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
|
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=computation_dtype)
|
||||||
with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
|
with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
|
||||||
model = model_loader(unet_config)
|
model = model_loader(unet_config)
|
||||||
else:
|
else:
|
||||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
|
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=storage_dtype)
|
||||||
need_manual_cast = storage_dtype != computation_dtype
|
need_manual_cast = storage_dtype != computation_dtype
|
||||||
to_args = dict(device=initial_device, dtype=storage_dtype)
|
to_args = dict(device=initial_device, dtype=storage_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -301,6 +301,13 @@ def state_dict_size(sd, exclude_device=None):
|
|||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
|
|
||||||
|
def state_dict_parameters(sd):
|
||||||
|
module_mem = 0
|
||||||
|
for k, v in sd.items():
|
||||||
|
module_mem += v.nelement()
|
||||||
|
return module_mem
|
||||||
|
|
||||||
|
|
||||||
def state_dict_dtype(state_dict):
|
def state_dict_dtype(state_dict):
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if hasattr(v, 'is_gguf'):
|
if hasattr(v, 'is_gguf'):
|
||||||
@@ -653,44 +660,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
|||||||
|
|
||||||
for candidate in supported_dtypes:
|
for candidate in supported_dtypes:
|
||||||
if candidate == torch.float16:
|
if candidate == torch.float16:
|
||||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
|
||||||
return candidate
|
return candidate
|
||||||
if candidate == torch.bfloat16:
|
if candidate == torch.bfloat16:
|
||||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
|
||||||
return candidate
|
return candidate
|
||||||
|
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
# None means no manual cast
|
def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
|
||||||
if weight_dtype == torch.float32:
|
|
||||||
return None
|
|
||||||
|
|
||||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
|
||||||
if fp16_supported and weight_dtype == torch.float16:
|
|
||||||
return None
|
|
||||||
|
|
||||||
bf16_supported = should_use_bf16(inference_device)
|
|
||||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if fp16_supported and torch.float16 in supported_dtypes:
|
|
||||||
return torch.float16
|
|
||||||
|
|
||||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
|
||||||
return torch.bfloat16
|
|
||||||
else:
|
|
||||||
return torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
|
||||||
for candidate in supported_dtypes:
|
for candidate in supported_dtypes:
|
||||||
if candidate == torch.float16:
|
if candidate == torch.float16:
|
||||||
if should_use_fp16(inference_device, prioritize_performance=False):
|
if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
|
||||||
return candidate
|
return candidate
|
||||||
if candidate == torch.bfloat16:
|
if candidate == torch.bfloat16:
|
||||||
if should_use_bf16(inference_device):
|
if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
|
||||||
return candidate
|
return candidate
|
||||||
|
|
||||||
return torch.float32
|
return torch.float32
|
||||||
@@ -1020,19 +1005,17 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if props.major < 6:
|
if props.major < 6:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
fp16_works = False
|
|
||||||
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
|
||||||
# when the model doesn't actually fit on the card
|
|
||||||
# TODO: actually test if GP106 and others have the same type of behavior
|
|
||||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||||
for x in nvidia_10_series:
|
for x in nvidia_10_series:
|
||||||
if x in props.name.lower():
|
if x in props.name.lower():
|
||||||
fp16_works = True
|
if manual_cast:
|
||||||
|
# For storage dtype
|
||||||
if fp16_works or manual_cast:
|
free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
|
||||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
return True
|
||||||
return True
|
else:
|
||||||
|
# For computation dtype
|
||||||
|
return False # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.
|
||||||
|
|
||||||
if props.major < 7:
|
if props.major < 7:
|
||||||
return False
|
return False
|
||||||
@@ -1080,7 +1063,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
bf16_works = torch.cuda.is_bf16_supported()
|
bf16_works = torch.cuda.is_bf16_supported()
|
||||||
|
|
||||||
if bf16_works or manual_cast:
|
if bf16_works or manual_cast:
|
||||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
|
||||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -1116,43 +1099,3 @@ def soft_empty_cache(force=False):
|
|||||||
|
|
||||||
def unload_all_models():
|
def unload_all_models():
|
||||||
free_memory(1e30, get_torch_device())
|
free_memory(1e30, get_torch_device())
|
||||||
|
|
||||||
|
|
||||||
def resolve_lowvram_weight(weight, model, key): # TODO: remove
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: might be cleaner to put this somewhere else
|
|
||||||
import threading
|
|
||||||
|
|
||||||
|
|
||||||
class InterruptProcessingException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
interrupt_processing_mutex = threading.RLock()
|
|
||||||
|
|
||||||
interrupt_processing = False
|
|
||||||
|
|
||||||
|
|
||||||
def interrupt_current_processing(value=True):
|
|
||||||
global interrupt_processing
|
|
||||||
global interrupt_processing_mutex
|
|
||||||
with interrupt_processing_mutex:
|
|
||||||
interrupt_processing = value
|
|
||||||
|
|
||||||
|
|
||||||
def processing_interrupted():
|
|
||||||
global interrupt_processing
|
|
||||||
global interrupt_processing_mutex
|
|
||||||
with interrupt_processing_mutex:
|
|
||||||
return interrupt_processing
|
|
||||||
|
|
||||||
|
|
||||||
def throw_exception_if_processing_interrupted():
|
|
||||||
global interrupt_processing
|
|
||||||
global interrupt_processing_mutex
|
|
||||||
with interrupt_processing_mutex:
|
|
||||||
if interrupt_processing:
|
|
||||||
interrupt_processing = False
|
|
||||||
raise InterruptProcessingException()
|
|
||||||
|
|||||||
@@ -438,7 +438,7 @@ class ControlLora(ControlNet):
|
|||||||
|
|
||||||
self.manual_cast_dtype = model.computation_dtype
|
self.manual_cast_dtype = model.computation_dtype
|
||||||
|
|
||||||
with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
|
with using_forge_operations(operations=ControlLoraOps, dtype=dtype, manual_cast_enabled=self.manual_cast_dtype != dtype):
|
||||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||||
|
|
||||||
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
|
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
|
||||||
|
|||||||
@@ -110,12 +110,12 @@ class ControlNetPatcher(ControlModelPatcher):
|
|||||||
controlnet_config['dtype'] = unet_dtype
|
controlnet_config['dtype'] = unet_dtype
|
||||||
|
|
||||||
load_device = memory_management.get_torch_device()
|
load_device = memory_management.get_torch_device()
|
||||||
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
|
computation_dtype = memory_management.get_computation_dtype(load_device)
|
||||||
|
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
|
|
||||||
with using_forge_operations(dtype=unet_dtype):
|
with using_forge_operations(dtype=unet_dtype, manual_cast_enabled=computation_dtype != unet_dtype):
|
||||||
control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype)
|
control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype)
|
||||||
|
|
||||||
if pth:
|
if pth:
|
||||||
@@ -139,7 +139,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
|||||||
# TODO: smarter way of enabling global_average_pooling
|
# TODO: smarter way of enabling global_average_pooling
|
||||||
global_average_pooling = True
|
global_average_pooling = True
|
||||||
|
|
||||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=computation_dtype)
|
||||||
return ControlNetPatcher(control)
|
return ControlNetPatcher(control)
|
||||||
|
|
||||||
def __init__(self, model_patcher):
|
def __init__(self, model_patcher):
|
||||||
|
|||||||
Reference in New Issue
Block a user