also implement async offload to control-lora

controlnet, t2iadapters, etc
This commit is contained in:
lllyasviel
2024-02-22 00:20:35 -08:00
parent 638ee43bf1
commit 846fdc3341

View File

@@ -14,6 +14,8 @@ import ldm_patched.modules.ops
import ldm_patched.controlnet.cldm
import ldm_patched.t2ia.adapter
from ldm_patched.modules.ops import main_thread_worker
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
@@ -303,11 +305,12 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = ldm_patched.modules.ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module):
def __init__(
@@ -343,11 +346,12 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = ldm_patched.modules.ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):