mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 03:31:30 +00:00
Add optimization --cuda-stream
See also the readme for more details
This commit is contained in:
@@ -116,6 +116,7 @@ parser.add_argument("--disable-server-info", action="store_true")
|
||||
parser.add_argument("--multi-user", action="store_true")
|
||||
|
||||
parser.add_argument("--cuda-malloc", action="store_true")
|
||||
parser.add_argument("--cuda-stream", action="store_true")
|
||||
parser.add_argument("--pin-shared-memory", action="store_true")
|
||||
|
||||
if ldm_patched.modules.options.args_parsing:
|
||||
|
||||
@@ -14,7 +14,7 @@ import ldm_patched.modules.ops
|
||||
import ldm_patched.controlnet.cldm
|
||||
import ldm_patched.t2ia.adapter
|
||||
|
||||
from ldm_patched.modules.ops import main_thread_worker
|
||||
from ldm_patched.modules.ops import main_stream_worker
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@@ -306,7 +306,7 @@ class ControlLoraOps:
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
@@ -347,7 +347,7 @@ class ControlLoraOps:
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
|
||||
@@ -215,9 +215,9 @@ class BaseModel(torch.nn.Module):
|
||||
dtype_size = ldm_patched.modules.model_management.dtype_size(dtype)
|
||||
|
||||
if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention():
|
||||
scaler = 1.25
|
||||
scaler = 1.28
|
||||
else:
|
||||
scaler = 1.75
|
||||
scaler = 1.65
|
||||
if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32":
|
||||
dtype_size = 4
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import time
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from ldm_patched.modules.args_parser import args
|
||||
from modules_forge import stream
|
||||
import ldm_patched.modules.utils
|
||||
import torch
|
||||
import sys
|
||||
@@ -277,6 +278,8 @@ if 'rtx' in torch_device_name.lower():
|
||||
print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
|
||||
if not args.cuda_malloc:
|
||||
print('Hint: your device supports --cuda-malloc for potential speed improvements.')
|
||||
if not args.cuda_stream:
|
||||
print('Hint: your device supports --cuda-stream for potential speed improvements.')
|
||||
|
||||
print("VAE dtype:", VAE_DTYPE)
|
||||
|
||||
@@ -326,7 +329,8 @@ class LoadedModel:
|
||||
raise e
|
||||
|
||||
if not disable_async_load:
|
||||
print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
|
||||
flag = 'ASYNC' if stream.using_stream else 'SYNC'
|
||||
print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
|
||||
real_async_memory = 0
|
||||
mem_counter = 0
|
||||
for m in self.real_model.modules():
|
||||
@@ -345,9 +349,9 @@ class LoadedModel:
|
||||
elif hasattr(m, "weight"):
|
||||
m.to(self.device)
|
||||
mem_counter += module_size(m)
|
||||
print("[Memory Management] Async Loader Disabled for ", m)
|
||||
print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
|
||||
print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
|
||||
print(f"[Memory Management] {flag} Loader Disabled for ", m)
|
||||
print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
|
||||
print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
|
||||
|
||||
self.model_accelerated = True
|
||||
|
||||
@@ -372,7 +376,7 @@ class LoadedModel:
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model and self.memory_required == other.memory_required
|
||||
return self.model is other.model # and self.memory_required == other.memory_required
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024)
|
||||
@@ -383,7 +387,8 @@ def unload_model_clones(model):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
print(f"Reuse {len(to_unload)} loaded models")
|
||||
if len(to_unload) > 0:
|
||||
print(f"Reuse {len(to_unload)} loaded models")
|
||||
|
||||
for i in to_unload:
|
||||
current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
|
||||
@@ -414,9 +419,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
global vram_state
|
||||
|
||||
execution_start_time = time.perf_counter()
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
extra_mem = max(minimum_inference_memory(), memory_required)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
@@ -439,7 +442,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
free_memory(extra_mem, d, models_already_loaded)
|
||||
|
||||
moving_time = time.perf_counter() - execution_start_time
|
||||
if moving_time > 0.01:
|
||||
if moving_time > 0.1:
|
||||
print(f'Memory cleanup has taken {moving_time:.2f} seconds')
|
||||
|
||||
return
|
||||
@@ -466,25 +469,25 @@ def load_models_gpu(models, memory_required=0):
|
||||
async_kept_memory = -1
|
||||
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
model_memory = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
estimated_memory_remaining = current_free_mem - model_size - extra_mem
|
||||
minimal_inference_memory = minimum_inference_memory()
|
||||
estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
|
||||
|
||||
print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
|
||||
print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
|
||||
print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
|
||||
print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
|
||||
print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
|
||||
print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
|
||||
print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
|
||||
print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))
|
||||
|
||||
if estimated_memory_remaining < 0:
|
||||
if estimated_remaining_memory < 0:
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
async_overhead_memory = 1024 * 1024 * 1024
|
||||
async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
|
||||
async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
|
||||
async_kept_memory = int(max(0, async_kept_memory))
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
async_kept_memory = 0
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(async_kept_memory)
|
||||
loaded_model.model_load(async_kept_memory)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
|
||||
moving_time = time.perf_counter() - execution_start_time
|
||||
|
||||
@@ -10,7 +10,7 @@ from modules_forge import stream
|
||||
|
||||
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
|
||||
gc = {}
|
||||
stash = {}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -31,26 +31,25 @@ def use_patched_ops(operations):
|
||||
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
context = contextlib.nullcontext
|
||||
signal = None
|
||||
weight, bias, signal = None, None, None
|
||||
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
|
||||
|
||||
if stream.using_stream:
|
||||
context = stream.stream_context()
|
||||
|
||||
with context(stream.mover_stream):
|
||||
bias = None
|
||||
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
|
||||
with stream.stream_context()(stream.mover_stream):
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
signal = stream.mover_stream.record_event()
|
||||
else:
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
|
||||
if stream.using_stream:
|
||||
signal = stream.mover_stream.record_event()
|
||||
return weight, bias, signal
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def main_thread_worker(weight, bias, signal):
|
||||
def main_stream_worker(weight, bias, signal):
|
||||
if not stream.using_stream or signal is None:
|
||||
yield
|
||||
return
|
||||
@@ -59,40 +58,25 @@ def main_thread_worker(weight, bias, signal):
|
||||
stream.current_stream.wait_event(signal)
|
||||
yield
|
||||
finished_signal = stream.current_stream.record_event()
|
||||
size = weight.element_size() * weight.nelement()
|
||||
if bias is not None:
|
||||
size += bias.element_size() * bias.nelement()
|
||||
gc[id(finished_signal)] = (weight, bias, finished_signal, size)
|
||||
|
||||
overhead = sum([l for k, (w, b, s, l) in gc.items()])
|
||||
|
||||
if overhead > 512 * 1024 * 1024:
|
||||
stream.mover_stream.synchronize()
|
||||
stream.current_stream.synchronize()
|
||||
stash[id(finished_signal)] = (weight, bias, finished_signal)
|
||||
|
||||
garbage = []
|
||||
for k, (w, b, s, l) in gc.items():
|
||||
for k, (w, b, s) in stash.items():
|
||||
if s.query():
|
||||
garbage.append(k)
|
||||
|
||||
for k in garbage:
|
||||
del gc[k]
|
||||
del stash[k]
|
||||
return
|
||||
|
||||
|
||||
def cleanup_cache():
|
||||
global gc
|
||||
if not stream.using_stream:
|
||||
return
|
||||
|
||||
if stream.current_stream is not None:
|
||||
with stream.stream_context()(stream.current_stream):
|
||||
for k, (w, b, s, l) in gc.items():
|
||||
stream.current_stream.wait_event(s)
|
||||
stream.current_stream.synchronize()
|
||||
|
||||
gc.clear()
|
||||
|
||||
if stream.mover_stream is not None:
|
||||
stream.mover_stream.synchronize()
|
||||
stream.current_stream.synchronize()
|
||||
stream.mover_stream.synchronize()
|
||||
stash.clear()
|
||||
return
|
||||
|
||||
|
||||
@@ -104,7 +88,7 @@ class disable_weight_init:
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@@ -120,7 +104,7 @@ class disable_weight_init:
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@@ -136,7 +120,7 @@ class disable_weight_init:
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@@ -152,7 +136,7 @@ class disable_weight_init:
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@@ -169,7 +153,7 @@ class disable_weight_init:
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
with main_stream_worker(weight, bias, signal):
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user