improved normalization for a network with varying batch network weights

Jaret Burkett
2023-08-28 12:42:57 -06:00
parent c446f768ea
commit 71da78c8af


@@ -158,17 +158,25 @@ class LoRAModule(torch.nn.Module):
    def forward(self, x):
        org_forwarded = self.org_forward(x)
        lora_output = self._call_forward(x)
        multiplier = self.get_multiplier(lora_output)
        if self.is_normalizing:
            with torch.no_grad():
                # do this calculation without multiplier
                # do this calculation without the set multiplier and instead use the same polarity, but with a 1.0 multiplier
                if isinstance(multiplier, torch.Tensor):
                    norm_multiplier = multiplier.clone().detach() * 10
                    norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
                else:
                    norm_multiplier = multiplier
                # get a dim array from org_forward that has the index of all dimensions except the batch and channel
                # Calculate the target magnitude for the combined output
                orig_max = torch.max(torch.abs(org_forwarded))
                # Calculate the additional increase in magnitude that lora_output would introduce
                potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))
                potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))
                epsilon = 1e-6  # Small constant to avoid division by zero
@@ -182,8 +190,6 @@ class LoRAModule(torch.nn.Module):
            lora_output *= normalize_scaler
        multiplier = self.get_multiplier(lora_output)
        return org_forwarded + (lora_output * multiplier)

    def enable_gradient_checkpointing(self):
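
Taken together, the change estimates how much the LoRA branch could inflate the layer's activations (using a sign-preserving, clamped copy of the multiplier rather than the configured strength) and scales the LoRA output back before merging it with the original output. Below is a minimal standalone sketch of the normalizing branch. The function name normalized_lora_forward is hypothetical, and the normalize_scaler formula (orig_max / (orig_max + potential_max_increase + epsilon)) is an assumption for illustration only, since the line that actually computes the scaler falls outside the visible hunks; the remaining steps mirror the diff above.

import torch


def normalized_lora_forward(org_forwarded: torch.Tensor,
                            lora_output: torch.Tensor,
                            multiplier):
    # Sign-preserving multiplier clamped to +/-1.0, so the scaler is estimated
    # independently of the configured LoRA strength (as in the diff above).
    if isinstance(multiplier, torch.Tensor):
        norm_multiplier = (multiplier.clone().detach() * 10).clamp(min=-1.0, max=1.0)
    else:
        norm_multiplier = multiplier

    with torch.no_grad():
        # Largest activation magnitude produced by the original layer.
        orig_max = torch.max(torch.abs(org_forwarded))
        # Additional magnitude the LoRA delta could introduce at full strength.
        potential_max_increase = torch.max(
            torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded)
        )
        epsilon = 1e-6  # small constant to avoid division by zero
        # Assumed scaler: close to 1.0 when the LoRA delta barely changes the
        # activation magnitudes, shrinking as the potential increase grows.
        normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)

    lora_output = lora_output * normalize_scaler
    return org_forwarded + (lora_output * multiplier)

With this assumed formula, the scaler stays near 1.0 when the LoRA delta barely changes the activation magnitudes and shrinks as the delta grows, keeping the combined output close to the magnitude range of the original layer even when the per-batch multiplier varies.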