Allow short and long caption combinations like form the new captioning system. Merge the network into the model before inference and reextract when done. Doubles inference speed on locon models during inference. allow splitting a batch into individual components and run them through alone. Basicallt gradient accumulation with single batch size.

This commit is contained in:
Jaret Burkett
2023-10-24 16:02:07 -06:00
parent 73c8b50975
commit 002279cec3
9 changed files with 315 additions and 115 deletions

View File

@@ -103,7 +103,21 @@ class ToolkitModuleMixin:
# this may get an additional positional arg or not
def forward(self: Module, x, *args, **kwargs):
if not self.network_ref().is_active:
skip = False
network = self.network_ref()
# skip if not active
if not network.is_active:
skip = True
# skip if is merged in
if network.is_merged_in:
skip = True
# skip if multiplier is 0
if network._multiplier == 0:
skip = True
if skip:
# network is not active, avoid doing anything
return self.org_forward(x, *args, **kwargs)
@@ -191,6 +205,52 @@ class ToolkitModuleMixin:
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
@torch.no_grad()
def merge_out(self: Module, merge_out_weight=1.0):
# make sure it is positive
merge_out_weight = abs(merge_out_weight)
# merging out is just merging in the negative of the weight
self.merge_in(merge_weight=-merge_out_weight)
@torch.no_grad()
def merge_in(self: Module, merge_weight=1.0):
# get up/down weight
up_weight = self.lora_up.weight.clone().float()
down_weight = self.lora_down.weight.clone().float()
# extract weight from org_module
org_sd = self.org_module[0].state_dict()
orig_dtype = org_sd["weight"].dtype
weight = org_sd["weight"].float()
multiplier = merge_weight
scale = self.scale
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
scale = scale * self.scalar
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + multiplier * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + multiplier * conved * scale
# set weight to org_module
org_sd["weight"] = weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
class ToolkitNetworkMixin:
def __init__(
@@ -210,6 +270,7 @@ class ToolkitNetworkMixin:
self._is_normalizing: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
# super().__init__(*args, **kwargs)
def get_keymap(self: Network):
@@ -326,7 +387,6 @@ class ToolkitNetworkMixin:
self.torch_multiplier = tensor_multiplier.clone().detach()
@property
def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@@ -396,3 +456,15 @@ class ToolkitNetworkMixin:
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)
def merge_in(self, merge_weight=1.0):
self.is_merged_in = True
for module in self.get_all_modules():
module.merge_in(merge_weight)
def merge_out(self, merge_weight=1.0):
if not self.is_merged_in:
return
self.is_merged_in = False
for module in self.get_all_modules():
module.merge_out(merge_weight)