Fixed DoRA implementation. Still highly experimental.

Jaret Burkett
2024-02-24 10:26:01 -07:00
parent 1bd94f0f01
commit f965a1299f
5 changed files with 128 additions and 33 deletions


@@ -52,8 +52,14 @@ def broadcast_and_multiply(tensor, multiplier):
     for _ in range(num_extra_dims):
         multiplier = multiplier.unsqueeze(-1)
-    # Multiplying the broadcasted tensor with the output tensor
-    result = tensor * multiplier
+    try:
+        # Multiplying the broadcasted tensor with the output tensor
+        result = tensor * multiplier
+    except RuntimeError as e:
+        print(e)
+        print(tensor.size())
+        print(multiplier.size())
+        raise e
     return result
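
Not part of the diff: a minimal sketch of the broadcasting this helper performs and the shape mismatch the new try/except will surface. The tensor shapes here are illustrative only.

import torch

tensor = torch.randn(4, 77, 768)                 # e.g. (batch, seq, dim)
multiplier = torch.tensor([1.0, 0.5, 1.0, 0.5])  # one scale per batch item

# unsqueeze until multiplier is (4, 1, 1), then broadcast-multiply
while multiplier.dim() < tensor.dim():
    multiplier = multiplier.unsqueeze(-1)
result = tensor * multiplier                     # -> (4, 77, 768)

# a multiplier whose leading dim disagrees with the batch, e.g. shape
# (2, 1, 1) against a batch of 4, raises the RuntimeError that the
# except block now reports along with both sizes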
@@ -248,9 +254,9 @@ class ToolkitModuleMixin:
             # network is not active, avoid doing anything
             return self.org_forward(x, *args, **kwargs)
-        if self.__class__.__name__ == "DoRAModule":
-            # return dora forward
-            return self.dora_forward(x, *args, **kwargs)
+        # if self.__class__.__name__ == "DoRAModule":
+        #     # return dora forward
+        #     return self.dora_forward(x, *args, **kwargs)
         org_forwarded = self.org_forward(x, *args, **kwargs)
         lora_output = self._call_forward(x)
@@ -263,7 +269,27 @@ class ToolkitModuleMixin:
             # todo check if this is correct, do we just concat when doing cfg?
             multiplier = multiplier.repeat_interleave(num_interleaves)
-        x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
+        scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
+        if self.__class__.__name__ == "DoRAModule":
+            # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
+            # x = dropout(x)
+            # todo this wont match the dropout applied to the lora
+            if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
+                lx = self.dropout(x)
+            # normal dropout
+            elif self.dropout is not None and self.training:
+                lx = torch.nn.functional.dropout(x, p=self.dropout)
+            else:
+                lx = x
+            lora_weight = self.lora_up.weight @ self.lora_down.weight
+            # scale it here
+            # todo handle our batch split scalers for slider training. For now take the mean of them
+            scale = multiplier.mean()
+            scaled_lora_weight = lora_weight * scale
+            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight)
+        x = org_forwarded + scaled_lora_output
         return x

     def enable_gradient_checkpointing(self: Module):
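
apply_dora itself is not shown in this diff. Going by the PEFT code referenced above, the returned term is presumably the per-column magnitude correction of DoRA, W' = m * (W + BA) / ||W + BA||. A hedged sketch of that computation; dora_term_sketch, base_weight, and magnitude are assumed names, not identifiers from this repo.

import torch
import torch.nn.functional as F

def dora_term_sketch(x, base_weight, scaled_lora_weight, magnitude):
    # effective weight of the adapted Linear layer: W + BA (BA pre-scaled)
    combined = base_weight + scaled_lora_weight                # (out, in)
    # column-wise L2 norm, detached so only the magnitude learns the scale
    weight_norm = torch.linalg.norm(combined, dim=1).detach()  # (out,)
    mag_scale = (magnitude / weight_norm).view(1, -1)          # (1, out)
    # adding this on top of the plain base + lora output yields
    # F.linear(x, magnitude * combined / ||combined||)
    return (mag_scale - 1) * F.linear(x, combined)

Under that reading, org_forwarded + scaled_lora_output above becomes the full DoRA forward once this correction term is folded in.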
@@ -413,12 +439,12 @@ class ToolkitNetworkMixin:
         new_keymap = {}
         for ldm_key, diffusers_key in keymap.items():
             ldm_key = ldm_key.replace('.alpha', '.magnitude')
-            ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
-            ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
+            # ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
+            # ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
             diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
-            diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
-            diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
+            # diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
+            # diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
             new_keymap[ldm_key] = diffusers_key
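
Illustration with made-up keys: after this change only the '.alpha' to '.magnitude' rename survives, so DoRA checkpoints keep the standard '.lora_down.weight' / '.lora_up.weight' names on disk.

# hypothetical keymap entries, for illustration only
keymap = {'lora_unet_a.alpha': 'unet.a.alpha',
          'lora_unet_a.lora_down.weight': 'unet.a.lora_down.weight'}

new_keymap = {}
for ldm_key, diffusers_key in keymap.items():
    ldm_key = ldm_key.replace('.alpha', '.magnitude')
    diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
    new_keymap[ldm_key] = diffusers_key
# {'lora_unet_a.magnitude': 'unet.a.magnitude',
#  'lora_unet_a.lora_down.weight': 'unet.a.lora_down.weight'}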
@@ -513,12 +539,8 @@ class ToolkitNetworkMixin:
         multiplier = self._multiplier
         # get first module
         first_module = self.get_all_modules()[0]
-        if self.network_type.lower() == 'dora':
-            device = first_module.lora_down.device
-            dtype = first_module.lora_down.dtype
-        else:
-            device = first_module.lora_down.weight.device
-            dtype = first_module.lora_down.weight.dtype
+        device = first_module.lora_down.weight.device
+        dtype = first_module.lora_down.weight.dtype
         with torch.no_grad():
             tensor_multiplier = None
             if isinstance(multiplier, int) or isinstance(multiplier, float):
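
The removed branch suggests DoRA modules previously stored lora_down as a bare tensor; with lora_down now a standard module for both network types, device and dtype can be read uniformly. A sketch under that assumption:

import torch.nn as nn

# lora_down as a plain nn.Linear for LoRA and DoRA alike (assumed)
lora_down = nn.Linear(768, 4, bias=False)
device = lora_down.weight.device   # same access path for either network type
dtype = lora_down.weight.dtype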