Merge upstream PR 14855

This commit is contained in:
lllyasviel
2024-02-21 23:59:40 -08:00
parent 95ddac3117
commit 638ee43bf1
4 changed files with 118 additions and 16 deletions

View File

@@ -6,6 +6,12 @@ import torch
import ldm_patched.modules.model_management
import contextlib
from modules_forge import stream
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
gc = {}
@contextlib.contextmanager
def use_patched_ops(operations):
@@ -25,12 +31,44 @@ def use_patched_ops(operations):
def cast_bias_weight(s, input):
bias = None
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias
context = contextlib.nullcontext
signal = None
if stream.using_stream:
context = stream.stream_context()
with context(stream.mover_stream):
bias = None
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if stream.using_stream:
signal = stream.mover_stream.record_event()
return weight, bias, signal
@contextlib.contextmanager
def main_thread_worker(weight, bias, signal):
if not stream.using_stream or signal is None:
yield
return
with stream.stream_context()(stream.current_stream):
stream.current_stream.wait_event(signal)
yield
finished_signal = stream.current_stream.record_event()
gc[id(finished_signal)] = (weight, bias, finished_signal)
garbage = []
for k, (w, b, s) in gc.items():
if s.query():
garbage.append(k)
for k in garbage:
del gc[k]
return
class disable_weight_init:
@@ -40,8 +78,9 @@ class disable_weight_init:
return None
def forward_ldm_patched_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.ldm_patched_cast_weights:
@@ -55,8 +94,9 @@ class disable_weight_init:
return None
def forward_ldm_patched_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.ldm_patched_cast_weights:
@@ -70,8 +110,9 @@ class disable_weight_init:
return None
def forward_ldm_patched_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.ldm_patched_cast_weights:
@@ -85,8 +126,9 @@ class disable_weight_init:
return None
def forward_ldm_patched_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.ldm_patched_cast_weights:
@@ -101,8 +143,9 @@ class disable_weight_init:
return None
def forward_ldm_patched_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.ldm_patched_cast_weights: