mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Allow short and long caption combinations like those from the new captioning system. Merge the network into the model before inference and re-extract it when done; this doubles inference speed on LoCon models. Allow splitting a batch into individual components and running them through one at a time. Basically gradient accumulation with a batch size of one.
This commit is contained in:
@@ -103,7 +103,21 @@ class ToolkitModuleMixin:
|
||||
# this may get an additional positional arg or not
|
||||
|
||||
def forward(self: Module, x, *args, **kwargs):
|
||||
if not self.network_ref().is_active:
|
||||
skip = False
|
||||
network = self.network_ref()
|
||||
# skip if not active
|
||||
if not network.is_active:
|
||||
skip = True
|
||||
|
||||
# skip if is merged in
|
||||
if network.is_merged_in:
|
||||
skip = True
|
||||
|
||||
# skip if multiplier is 0
|
||||
if network._multiplier == 0:
|
||||
skip = True
|
||||
|
||||
if skip:
|
||||
# network is not active, avoid doing anything
|
||||
return self.org_forward(x, *args, **kwargs)
|
||||
|
||||
@@ -191,6 +205,52 @@ class ToolkitModuleMixin:
|
||||
# reset the normalization scaler
|
||||
self.normalize_scaler = target_normalize_scaler
|
||||
|
||||
@torch.no_grad()
def merge_out(self: Module, merge_out_weight=1.0):
    """Undo a merge by merging the same delta back in with a negative weight.

    The sign of ``merge_out_weight`` is ignored: its absolute value is
    negated and forwarded to ``self.merge_in``.

    Args:
        merge_out_weight: Magnitude of the merge to remove.
    """
    # merging out is just merging in the negative of the (positive) weight
    negated_weight = -abs(merge_out_weight)
    self.merge_in(merge_weight=negated_weight)
|
||||
|
||||
@torch.no_grad()
def merge_in(self: Module, merge_weight=1.0):
    """Add this module's low-rank delta to the weight of ``self.org_module[0]``.

    The delta ``merge_weight * (lora_up x lora_down) * scale`` is computed in
    float32, added to the ``"weight"`` entry of the original module's state
    dict, cast back to that weight's original dtype and loaded into the
    module. A negative ``merge_weight`` subtracts the delta instead.

    Args:
        merge_weight: Multiplier applied to the delta before it is added.
    """
    # extract weight from org_module; it fixes the dtype/device of the result
    org_sd = self.org_module[0].state_dict()
    orig_dtype = org_sd["weight"].dtype
    weight = org_sd["weight"].float()

    # get up/down weight in float32 (nothing below mutates them in place,
    # so no defensive clone is needed)
    up_weight = self.lora_up.weight.float()
    down_weight = self.lora_down.weight.float()

    scale = self.scale
    # handle trainable scaler method locon does
    if hasattr(self, 'scalar'):
        scale = scale * self.scalar

    # build the un-scaled delta in the layout of the original weight
    if len(weight.size()) == 2:
        # linear
        delta = up_weight @ down_weight
    elif down_weight.size()[2:4] == (1, 1):
        # conv2d 1x1: treat both kernels as plain matrices
        delta = (
            (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
            .unsqueeze(2)
            .unsqueeze(3)
        )
    else:
        # conv2d 3x3: compose the two kernels with a convolution
        delta = torch.nn.functional.conv2d(
            down_weight.permute(1, 0, 2, 3), up_weight
        ).permute(1, 0, 2, 3)

    # scale on the LoRA side, then move to wherever the original weight
    # lives so a network and base layer on different devices still merge
    delta = merge_weight * delta * scale
    weight = weight + delta.to(weight.device)

    # set weight to org_module
    org_sd["weight"] = weight.to(orig_dtype)
    self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
|
||||
class ToolkitNetworkMixin:
|
||||
def __init__(
|
||||
@@ -210,6 +270,7 @@ class ToolkitNetworkMixin:
|
||||
self._is_normalizing: bool = False
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
self.is_merged_in = False
|
||||
# super().__init__(*args, **kwargs)
|
||||
|
||||
def get_keymap(self: Network):
|
||||
@@ -326,7 +387,6 @@ class ToolkitNetworkMixin:
|
||||
|
||||
self.torch_multiplier = tensor_multiplier.clone().detach()
|
||||
|
||||
|
||||
@property
def multiplier(self) -> Union[float, List[float]]:
    """Return the value currently stored in ``self._multiplier``."""
    current_multiplier = self._multiplier
    return current_multiplier
|
||||
@@ -396,3 +456,15 @@ class ToolkitNetworkMixin:
|
||||
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
    """Call ``apply_stored_normalizer`` on every module of the network.

    Args:
        target_normalize_scaler: Value forwarded unchanged to each module.
    """
    for lora_module in self.get_all_modules():
        lora_module.apply_stored_normalizer(target_normalize_scaler)
|
||||
|
||||
def merge_in(self, merge_weight=1.0):
    """Merge every module of the network into its original layer, once.

    ``is_merged_in`` is a single boolean and ``merge_out`` only reverts one
    merge, so a second call while the network is already merged would stack
    a delta that can never be removed. Such a call is therefore a no-op.

    Args:
        merge_weight: Multiplier forwarded to each module's ``merge_in``.
    """
    if self.is_merged_in:
        # already baked into the base weights; merging again is irreversible
        return
    self.is_merged_in = True
    for module in self.get_all_modules():
        module.merge_in(merge_weight)
|
||||
|
||||
def merge_out(self, merge_weight=1.0):
    """Revert a merge on every module; does nothing if nothing is merged in.

    Args:
        merge_weight: Value forwarded to each module's ``merge_out``.
    """
    if self.is_merged_in:
        # clear the flag first, then remove the delta from each module
        self.is_merged_in = False
        for lora_module in self.get_all_modules():
            lora_module.merge_out(merge_weight)
|
||||
|
||||
Reference in New Issue
Block a user