From 846fdc3341ae75160d96871bff2eeaffd25d638a Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 22 Feb 2024 00:20:35 -0800 Subject: [PATCH] also implement async offload to control-lora controlnet, t2iadapters, etc --- ldm_patched/modules/controlnet.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 94d01cdf..296941b5 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -14,6 +14,8 @@ import ldm_patched.modules.ops import ldm_patched.controlnet.cldm import ldm_patched.t2ia.adapter +from ldm_patched.modules.ops import main_thread_worker + def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] @@ -303,11 +305,12 @@ class ControlLoraOps: self.bias = None def forward(self, input): - weight, bias = ldm_patched.modules.ops.cast_bias_weight(self, input) - if self.up is not None: - return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) - else: - return torch.nn.functional.linear(input, weight, bias) + weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + if self.up is not None: + return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) + else: + return torch.nn.functional.linear(input, weight, bias) class Conv2d(torch.nn.Module): def __init__( @@ -343,11 +346,12 @@ class ControlLoraOps: def forward(self, input): - weight, bias = ldm_patched.modules.ops.cast_bias_weight(self, input) - if self.up is not None: - return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) - else: - return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input) + with main_thread_worker(weight, bias, signal): + if self.up is not None: + return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) + else: + return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet):