From 638ee43bf1cc9fe894435ca6d7830645a0046339 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 21 Feb 2024 23:59:40 -0800 Subject: [PATCH] Merge upstream PR 14855 --- ldm_patched/modules/ops.py | 75 ++++++++++++++++++++++++++-------- modules/shared_options.py | 1 + modules_forge/forge_sampler.py | 2 + modules_forge/stream.py | 56 +++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 16 deletions(-) create mode 100644 modules_forge/stream.py diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py index 57eab3a8..fdbaa064 100644 --- a/ldm_patched/modules/ops.py +++ b/ldm_patched/modules/ops.py @@ -6,6 +6,12 @@ import torch import ldm_patched.modules.model_management import contextlib +from modules_forge import stream + + +# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files +gc = {} + @contextlib.contextmanager def use_patched_ops(operations): @@ -25,12 +31,44 @@ def use_patched_ops(operations): def cast_bias_weight(s, input): - bias = None - non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device) - if s.bias is not None: - bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) - weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) - return weight, bias + context = contextlib.nullcontext + signal = None + + if stream.using_stream: + context = stream.stream_context() + + with context(stream.mover_stream): + bias = None + non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device) + if s.bias is not None: + bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + + if stream.using_stream: + signal = stream.mover_stream.record_event() + return weight, bias, signal + + +@contextlib.contextmanager +def main_thread_worker(weight, bias, signal): + if not stream.using_stream or signal is None: + yield + return + + with stream.stream_context()(stream.current_stream): + stream.current_stream.wait_event(signal) + yield + finished_signal = stream.current_stream.record_event() + gc[id(finished_signal)] = (weight, bias, finished_signal) + + garbage = [] + for k, (w, b, s) in gc.items(): + if s.query(): + garbage.append(k) + + for k in garbage: + del gc[k] + return class disable_weight_init: @@ -40,8 +78,9 @@ class disable_weight_init: return None def forward_ldm_patched_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, signal = cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): if self.ldm_patched_cast_weights: @@ -55,8 +94,9 @@ class disable_weight_init: return None def forward_ldm_patched_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, signal = cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): if self.ldm_patched_cast_weights: @@ -70,8 +110,9 @@ class disable_weight_init: return None def forward_ldm_patched_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, signal = cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): if self.ldm_patched_cast_weights: @@ -85,8 +126,9 @@ class disable_weight_init: return None def forward_ldm_patched_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + weight, bias, signal = cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): if self.ldm_patched_cast_weights: @@ -101,8 +143,9 @@ class disable_weight_init: return None def forward_ldm_patched_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + weight, bias, signal = cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): if self.ldm_patched_cast_weights: diff --git a/modules/shared_options.py b/modules/shared_options.py index 04120c6a..504f1b24 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -216,6 +216,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd" "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), + "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("(Requires restart in Forge.) Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."), })) options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index bac7d156..de5fb72c 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -2,6 +2,7 @@ import torch from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn from ldm_patched.modules.samplers import sampling_function from ldm_patched.modules import model_management +from modules_forge.stream import synchronize_current_stream def cond_from_a1111_to_patched_ldm(cond): @@ -113,4 +114,5 @@ def sampling_prepare(unet, x): def sampling_cleanup(unet): for cnet in unet.list_controlnets(): cnet.cleanup() + synchronize_current_stream() return diff --git a/modules_forge/stream.py b/modules_forge/stream.py new file mode 100644 index 00000000..ccb9784c --- /dev/null +++ b/modules_forge/stream.py @@ -0,0 +1,56 @@ +# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855 + +import torch + +from modules import shared +from ldm_patched.modules import model_management + + +def stream_context(): + if torch.cuda.is_available(): + return torch.cuda.stream + + if model_management.is_intel_xpu(): + return torch.xpu.stream + + return None + + +def get_current_stream(): + try: + if torch.cuda.is_available(): + return torch.cuda.current_stream(torch.device(torch.cuda.current_device())) + if model_management.is_intel_xpu(): + return torch.xpu.current_stream(torch.device("xpu")) + except: + pass + print('Stream is not used.') + return None + + +def get_new_stream(): + try: + if torch.cuda.is_available(): + return torch.cuda.Stream(torch.device(torch.cuda.current_device())) + if model_management.is_intel_xpu(): + return torch.xpu.Stream(torch.device("xpu")) + except: + pass + print('Stream is not used.') + return None + + +def synchronize_current_stream(): + global current_stream + if current_stream is not None: + current_stream.synchronize() + + +if shared.opts.use_non_streamlined_lowvram: + current_stream = None + mover_stream = None + using_stream = False +else: + current_stream = get_current_stream() + mover_stream = get_new_stream() + using_stream = current_stream is not None and mover_stream is not None