mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-03 03:59:50 +00:00
also implement async offload to control-lora
controlnet, t2iadapters, etc
This commit is contained in:
@@ -14,6 +14,8 @@ import ldm_patched.modules.ops
|
||||
import ldm_patched.controlnet.cldm
|
||||
import ldm_patched.t2ia.adapter
|
||||
|
||||
from ldm_patched.modules.ops import main_thread_worker
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
current_batch_size = tensor.shape[0]
|
||||
@@ -303,11 +305,12 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -343,11 +346,12 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
|
||||
Reference in New Issue
Block a user