Add optimization --cuda-stream

See also the README for more details
Author: lllyasviel
Date: 2024-02-24 14:00:48 -08:00 (committed by GitHub)
Parent: 0f09d98814
Commit: 434ca2169f
9 changed files with 63 additions and 73 deletions

ldm_patched/modules/args_parser.py

@@ -116,6 +116,7 @@ parser.add_argument("--disable-server-info", action="store_true")
 parser.add_argument("--multi-user", action="store_true")
 parser.add_argument("--cuda-malloc", action="store_true")
+parser.add_argument("--cuda-stream", action="store_true")
 parser.add_argument("--pin-shared-memory", action="store_true")

 if ldm_patched.modules.options.args_parsing:
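
The flag is only registered here; the behavior lives in `modules_forge/stream.py`, imported later in this commit. A hedged sketch of what such a helper plausibly looks like, using only public `torch.cuda` APIs; the names (`using_stream`, `current_stream`, `mover_stream`, `stream_context`) mirror how the diff calls it, but the real module may differ:

```python
import torch

using_stream = False
current_stream = None
mover_stream = None

def stream_context():
    # torch.cuda.stream is a context-manager factory, so callers can write
    # stream_context()(mover_stream) exactly as the diff does.
    return torch.cuda.stream

def init_stream(enabled):
    global using_stream, current_stream, mover_stream
    if enabled and torch.cuda.is_available():
        current_stream = torch.cuda.current_stream()  # default compute stream
        mover_stream = torch.cuda.Stream()            # extra stream for weight moving
        using_stream = current_stream is not None and mover_stream is not None
```
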

ldm_patched/modules/controlnet.py

@@ -14,7 +14,7 @@ import ldm_patched.modules.ops
 import ldm_patched.controlnet.cldm
 import ldm_patched.t2ia.adapter
-from ldm_patched.modules.ops import main_thread_worker
+from ldm_patched.modules.ops import main_stream_worker

 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -306,7 +306,7 @@ class ControlLoraOps:
         def forward(self, input):
             weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
                 else:
@@ -347,7 +347,7 @@ class ControlLoraOps:
         def forward(self, input):
             weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
                 else:
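
Both hunks above only rename the context manager; the surrounding math is the ControlLora low-rank update, where the effective weight is W plus up·down reshaped to W's shape. A standalone illustration with made-up shapes:

```python
import torch

out_f, in_f, rank = 8, 16, 4
weight = torch.randn(out_f, in_f)
up = torch.randn(out_f, rank)      # plays the role of self.up
down = torch.randn(rank, in_f)     # plays the role of self.down
x = torch.randn(2, in_f)

# flatten(start_dim=1) is a no-op on 2-D tensors; it matters for conv weights.
merged = weight + torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(weight.shape)
y = torch.nn.functional.linear(x, merged)
print(y.shape)  # torch.Size([2, 8])
```
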

ldm_patched/modules/model_base.py

@@ -215,9 +215,9 @@ class BaseModel(torch.nn.Module):
             dtype_size = ldm_patched.modules.model_management.dtype_size(dtype)
             if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention():
-                scaler = 1.25
+                scaler = 1.28
             else:
-                scaler = 1.75
+                scaler = 1.65
             if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32":
                 dtype_size = 4
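
These multipliers feed BaseModel's inference-memory estimate: the flash/xformers path gets a slightly larger factor (1.25 → 1.28) while the fallback path gets a smaller one (1.75 → 1.65). Assuming the estimate scales linearly in the factor (the exact formula lives in `memory_required` and is not shown in this hunk), the relative effect is:

```python
# Illustration only; base_estimate_mb is a made-up baseline.
base_estimate_mb = 4000.0
print(f"{base_estimate_mb * 1.28 / 1.25:.0f} MB")  # flash path: 4096, about +2.4%
print(f"{base_estimate_mb * 1.65 / 1.75:.0f} MB")  # fallback path: 3771, about -5.7%
```
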

ldm_patched/modules/model_management.py

@@ -6,6 +6,7 @@ import time
 import psutil
 from enum import Enum
 from ldm_patched.modules.args_parser import args
+from modules_forge import stream
 import ldm_patched.modules.utils
 import torch
 import sys
@@ -277,6 +278,8 @@ if 'rtx' in torch_device_name.lower():
         print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
     if not args.cuda_malloc:
         print('Hint: your device supports --cuda-malloc for potential speed improvements.')
+    if not args.cuda_stream:
+        print('Hint: your device supports --cuda-stream for potential speed improvements.')

 print("VAE dtype:", VAE_DTYPE)
@@ -326,7 +329,8 @@ class LoadedModel:
             raise e

         if not disable_async_load:
-            print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
+            flag = 'ASYNC' if stream.using_stream else 'SYNC'
+            print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
             real_async_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
@@ -345,9 +349,9 @@ class LoadedModel:
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("[Memory Management] Async Loader Disabled for ", m)
-            print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
-            print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
+                    print(f"[Memory Management] {flag} Loader Disabled for ", m)
+            print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
+            print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))

         self.model_accelerated = True
@@ -372,7 +376,7 @@ class LoadedModel:
         self.model.model_patches_to(self.model.offload_device)

     def __eq__(self, other):
-        return self.model is other.model and self.memory_required == other.memory_required
+        return self.model is other.model # and self.memory_required == other.memory_required

 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -383,7 +387,8 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload

-    print(f"Reuse {len(to_unload)} loaded models")
+    if len(to_unload) > 0:
+        print(f"Reuse {len(to_unload)} loaded models")

     for i in to_unload:
         current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
@@ -414,9 +419,7 @@ def load_models_gpu(models, memory_required=0):
     global vram_state

     execution_start_time = time.perf_counter()
-    inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required)
+    extra_mem = max(minimum_inference_memory(), memory_required)

     models_to_load = []
     models_already_loaded = []
@@ -439,7 +442,7 @@ def load_models_gpu(models, memory_required=0):
             free_memory(extra_mem, d, models_already_loaded)

         moving_time = time.perf_counter() - execution_start_time
-        if moving_time > 0.01:
+        if moving_time > 0.1:
             print(f'Memory cleanup has taken {moving_time:.2f} seconds')

         return
@@ -466,25 +469,25 @@ def load_models_gpu(models, memory_required=0):
         async_kept_memory = -1

         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
+            model_memory = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            estimated_memory_remaining = current_free_mem - model_size - extra_mem
+            minimal_inference_memory = minimum_inference_memory()
+            estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory

-            print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
-            print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
-            print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
-            print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
+            print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
+            print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
+            print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
+            print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))

-            if estimated_memory_remaining < 0:
+            if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
-                async_overhead_memory = 1024 * 1024 * 1024
-                async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
+                async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
                 async_kept_memory = int(max(0, async_kept_memory))

         if vram_set_state == VRAMState.NO_VRAM:
             async_kept_memory = 0

-        cur_loaded_model = loaded_model.model_load(async_kept_memory)
+        loaded_model.model_load(async_kept_memory)
         current_loaded_models.insert(0, loaded_model)

     moving_time = time.perf_counter() - execution_start_time
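
The low-VRAM branch changes how the kept-resident budget is computed: instead of free memory minus the requested inference memory minus a fixed 1 GiB overhead, it now reserves only the 1 GiB minimal inference memory and keeps 1/1.3 of the remainder. A worked example with made-up numbers (in MB):

```python
current_free_mem = 7000          # get_free_memory(torch_dev), made-up value
minimal_inference_memory = 1024  # minimum_inference_memory() is 1 GiB

async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
async_kept_memory = int(max(0, async_kept_memory))
print(async_kept_memory)  # 4596 -> roughly 4.5 GB of weights stay resident
```
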

ldm_patched/modules/ops.py

@@ -10,7 +10,7 @@ from modules_forge import stream
 # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files

-gc = {}
+stash = {}

 @contextlib.contextmanager
@@ -31,26 +31,25 @@ def use_patched_ops(operations):
 def cast_bias_weight(s, input):
-    context = contextlib.nullcontext
-    signal = None
+    weight, bias, signal = None, None, None
+    non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
     if stream.using_stream:
-        context = stream.stream_context()
-
-    with context(stream.mover_stream):
-        bias = None
-        non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
+        with stream.stream_context()(stream.mover_stream):
+            if s.bias is not None:
+                bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+            weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+            signal = stream.mover_stream.record_event()
+    else:
         if s.bias is not None:
             bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
         weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-        if stream.using_stream:
-            signal = stream.mover_stream.record_event()

     return weight, bias, signal

 @contextlib.contextmanager
-def main_thread_worker(weight, bias, signal):
+def main_stream_worker(weight, bias, signal):
     if not stream.using_stream or signal is None:
         yield
         return
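
The rewritten cast_bias_weight is the producer half of a classic two-stream pattern: copy weights to the GPU on a side ("mover") stream, record an event, and have the compute stream wait on that event before touching the weights. A self-contained demo using only public `torch.cuda` APIs, with names chosen to echo the diff:

```python
import torch

assert torch.cuda.is_available()
device = torch.device("cuda")
compute = torch.cuda.current_stream(device)   # the default compute stream
mover = torch.cuda.Stream(device)             # side stream for weight copies

w_cpu = torch.randn(1024, 1024, pin_memory=True)  # pinned memory enables true async H2D copies
x = torch.randn(64, 1024, device=device)

with torch.cuda.stream(mover):                # like stream.stream_context()(stream.mover_stream)
    w = w_cpu.to(device, non_blocking=True)
    signal = mover.record_event()             # like stream.mover_stream.record_event()

compute.wait_event(signal)                    # main_stream_worker's wait before compute
y = x @ w.t()                                 # ordered after the copy, so this is safe
torch.cuda.synchronize()
print(y.shape)                                # torch.Size([64, 1024])
```
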
@@ -59,40 +58,25 @@ def main_thread_worker(weight, bias, signal):
     stream.current_stream.wait_event(signal)
     yield
     finished_signal = stream.current_stream.record_event()
-    size = weight.element_size() * weight.nelement()
-    if bias is not None:
-        size += bias.element_size() * bias.nelement()
-    gc[id(finished_signal)] = (weight, bias, finished_signal, size)
-
-    overhead = sum([l for k, (w, b, s, l) in gc.items()])
-    if overhead > 512 * 1024 * 1024:
-        stream.mover_stream.synchronize()
-        stream.current_stream.synchronize()
+    stash[id(finished_signal)] = (weight, bias, finished_signal)

     garbage = []
-    for k, (w, b, s, l) in gc.items():
+    for k, (w, b, s) in stash.items():
         if s.query():
             garbage.append(k)

     for k in garbage:
-        del gc[k]
+        del stash[k]
     return

 def cleanup_cache():
-    global gc
-
     if not stream.using_stream:
         return

-    if stream.current_stream is not None:
-        with stream.stream_context()(stream.current_stream):
-            for k, (w, b, s, l) in gc.items():
-                stream.current_stream.wait_event(s)
-            stream.current_stream.synchronize()
-        gc.clear()
-
-    if stream.mover_stream is not None:
-        stream.mover_stream.synchronize()
+    stream.current_stream.synchronize()
+    stream.mover_stream.synchronize()
+    stash.clear()
     return
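
This hunk replaces the byte-accounted `gc` dict (which tracked tensor sizes and force-synchronized once more than 512 MB of weights were in flight) with a plain `stash` keyed by event id: each entry just keeps the tensors alive until `event.query()` reports that the compute stream has passed them. The retirement logic in isolation, assuming the stored values are `torch.cuda.Event` objects:

```python
# stash maps id(event) -> (weight, bias, event); holding the tuple prevents
# Python from freeing the tensors while a queued kernel may still read them.
stash = {}

def retire_finished():
    # Event.query() is non-blocking: True once all work recorded before the
    # event has completed on the GPU.
    done = [k for k, (w, b, ev) in stash.items() if ev.query()]
    for k in done:
        del stash[k]
```
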
@@ -104,7 +88,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.linear(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -120,7 +104,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -136,7 +120,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
@@ -152,7 +136,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
@@ -169,7 +153,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
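
The five hunks above are the same mechanical rename across Linear, the two conv classes, GroupNorm, and LayerNorm: every patched forward casts parameters per call, then runs the kernel inside the guard. A toy standalone version of that shared shape, with `contextlib.nullcontext` standing in for `main_stream_worker`:

```python
import contextlib
import torch

class CastLinear(torch.nn.Linear):
    # Mirrors forward_ldm_patched_cast_weights: cast parameters to the
    # input's dtype on every call, then run the op inside a guard context.
    def forward(self, input):
        weight = self.weight.to(dtype=input.dtype)
        bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
        with contextlib.nullcontext():  # stand-in for main_stream_worker(...)
            return torch.nn.functional.linear(input, weight, bias)

layer = CastLinear(16, 8).half()  # fp16 parameters
y = layer(torch.randn(2, 16))     # fp32 input
print(y.dtype)                    # torch.float32: weights were cast up per call
```
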