From 71da78c8af1ae2fb67b0acff3d012fc601b57906 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 28 Aug 2023 12:42:57 -0600
Subject: [PATCH] improved normalization for a network with varying batch network weights

---
 toolkit/lora_special.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 0a477087..1dc4dbc4 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -158,17 +158,25 @@ class LoRAModule(torch.nn.Module):
     def forward(self, x):
         org_forwarded = self.org_forward(x)
         lora_output = self._call_forward(x)
+        multiplier = self.get_multiplier(lora_output)
 
         if self.is_normalizing:
             with torch.no_grad():
-                # do this calculation without multiplier
+
+                # do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier
+                if isinstance(multiplier, torch.Tensor):
+                    norm_multiplier = multiplier.clone().detach() * 10
+                    norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
+                else:
+                    norm_multiplier = multiplier
+
                 # get a dim array from orig forward that had index of all dimensions except the batch and channel
 
                 # Calculate the target magnitude for the combined output
                 orig_max = torch.max(torch.abs(org_forwarded))
 
                 # Calculate the additional increase in magnitude that lora_output would introduce
-                potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
+                potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))
 
                 epsilon = 1e-6  # Small constant to avoid division by zero
 
@@ -182,8 +190,6 @@ class LoRAModule(torch.nn.Module):
 
             lora_output *= normalize_scaler
 
-        multiplier = self.get_multiplier(lora_output)
-
         return org_forwarded + (lora_output * multiplier)
 
     def enable_gradient_checkpointing(self):
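
For context, below is a minimal standalone sketch of the multiplier handling this patch adds, runnable outside the module. The helper name polarity_preserving_multiplier and the example tensor are illustrative assumptions, not part of the patch; the actual change lives in LoRAModule.forward above.

import torch

def polarity_preserving_multiplier(multiplier):
    # Mirrors the patched normalization step: for tensor multipliers
    # (e.g. one value per batch item), keep each entry's sign but scale
    # it by 10 and clamp to [-1, 1], so most entries saturate to +/-1
    # and the magnitude measurement is not skewed by small multipliers.
    if isinstance(multiplier, torch.Tensor):
        return (multiplier.clone().detach() * 10).clamp(min=-1.0, max=1.0)
    # Scalar multipliers are passed through unchanged, as in the patch.
    return multiplier

# Example: a per-batch-item multiplier tensor
m = torch.tensor([0.05, -0.3, 1.0])
print(polarity_preserving_multiplier(m))  # tensor([ 0.5000, -1.0000,  1.0000])

This norm_multiplier is then used in place of the configured multiplier when measuring how much the LoRA branch could increase the activation magnitude (potential_max_increase), while the final output still applies the original multiplier.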