Merge upstream PR 14855
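Ports the stream-based asynchronous weight casting from AUTOMATIC1111/stable-diffusion-webui PR 14855 into ldm_patched's cast ops. When stream.using_stream is enabled, cast_bias_weight now copies each layer's weight and bias to the target device on a dedicated mover stream and returns a CUDA event (signal) recorded after the copies. The new main_thread_worker context manager makes the current compute stream wait on that event before the layer runs, then records a finished_signal and parks the cast tensors in the module-level gc dict so their memory is not freed and reused before the compute stream has consumed them; completed entries are reaped via Event.query(). Every forward_ldm_patched_cast_weights implementation (Linear, the Conv layers, GroupNorm, LayerNorm) is wrapped accordingly.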
@@ -6,6 +6,12 @@ import torch
 import ldm_patched.modules.model_management
 import contextlib
 
+from modules_forge import stream
+
+
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
+gc = {}
+
 
 @contextlib.contextmanager
 def use_patched_ops(operations):
@@ -25,12 +31,44 @@ def use_patched_ops(operations):
 
 
 def cast_bias_weight(s, input):
-    bias = None
-    non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
-    if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-    return weight, bias
+    context = contextlib.nullcontext
+    signal = None
+
+    if stream.using_stream:
+        context = stream.stream_context()
+
+    with context(stream.mover_stream):
+        bias = None
+        non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
+        if s.bias is not None:
+            bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+
+    if stream.using_stream:
+        signal = stream.mover_stream.record_event()
+
+    return weight, bias, signal
+
+
+@contextlib.contextmanager
+def main_thread_worker(weight, bias, signal):
+    if not stream.using_stream or signal is None:
+        yield
+        return
+
+    with stream.stream_context()(stream.current_stream):
+        stream.current_stream.wait_event(signal)
+        yield
+        finished_signal = stream.current_stream.record_event()
+        gc[id(finished_signal)] = (weight, bias, finished_signal)
+
+    garbage = []
+    for k, (w, b, s) in gc.items():
+        if s.query():
+            garbage.append(k)
+    for k in garbage:
+        del gc[k]
+    return
 
 
 class disable_weight_init:
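For readers unfamiliar with the two-stream pattern above, the following is a self-contained sketch of the same hand-off in plain PyTorch. It is an illustration only: prefetch, run_layer, keep_alive, and the stream objects are hypothetical stand-ins, not Forge's modules_forge.stream API.

import torch

# Hypothetical sketch of what cast_bias_weight / main_thread_worker do above.
device = torch.device('cuda')
compute_stream = torch.cuda.current_stream(device)
mover_stream = torch.cuda.Stream(device=device)

keep_alive = {}  # plays the role of the module-level `gc` dict


def prefetch(module, device, dtype):
    # Issue the host->device copies on the mover stream so they can overlap
    # with whatever the compute stream is currently doing.
    with torch.cuda.stream(mover_stream):
        weight = module.weight.to(device=device, dtype=dtype, non_blocking=True)
        bias = None
        if module.bias is not None:
            bias = module.bias.to(device=device, dtype=dtype, non_blocking=True)
    signal = mover_stream.record_event()  # fires once both copies are done
    return weight, bias, signal


def run_layer(x, weight, bias, signal):
    # Compute must not start before the copies finish ...
    compute_stream.wait_event(signal)
    out = torch.nn.functional.linear(x, weight, bias)
    # ... and the cast tensors must stay referenced until compute has read
    # them, or their memory could be reused while the GPU still needs it.
    done = compute_stream.record_event()
    keep_alive[id(done)] = (weight, bias, done)
    for k in [k for k, (_, _, e) in keep_alive.items() if e.query()]:
        del keep_alive[k]  # Event.query() is a non-blocking completion check
    return out

Note that non_blocking=True generally only yields a truly asynchronous copy when the source tensors live in pinned host memory; from pageable memory the transfer effectively degrades to a synchronous one.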
@@ -40,8 +78,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return torch.nn.functional.linear(input, weight, bias)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -55,8 +94,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -70,8 +110,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -85,8 +126,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -101,8 +143,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
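As a toy end-to-end usage, mirroring how each patched forward_ldm_patched_cast_weights above wraps its op (again built on the hypothetical prefetch/run_layer helpers from the earlier sketch, not Forge's actual classes):

# PrefetchLinear and the helpers it calls are hypothetical, for illustration.
class PrefetchLinear(torch.nn.Linear):
    def forward(self, x):
        weight, bias, signal = prefetch(self, x.device, x.dtype)
        return run_layer(x, weight, bias, signal)


layer = PrefetchLinear(512, 512)                    # parameters stay on CPU
layer.weight.data = layer.weight.data.pin_memory()  # pinned => async copies
if layer.bias is not None:
    layer.bias.data = layer.bias.data.pin_memory()

x = torch.randn(8, 512, device='cuda')
y = layer(x)  # weights stream to the GPU while earlier GPU work proceeds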