mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Fixed Dora implementation. Still highly experimental
This commit is contained in:
@@ -52,8 +52,14 @@ def broadcast_and_multiply(tensor, multiplier):
|
||||
for _ in range(num_extra_dims):
|
||||
multiplier = multiplier.unsqueeze(-1)
|
||||
|
||||
# Multiplying the broadcasted tensor with the output tensor
|
||||
result = tensor * multiplier
|
||||
try:
|
||||
# Multiplying the broadcasted tensor with the output tensor
|
||||
result = tensor * multiplier
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
print(tensor.size())
|
||||
print(multiplier.size())
|
||||
raise e
|
||||
|
||||
return result
|
||||
|
||||
@@ -248,9 +254,9 @@ class ToolkitModuleMixin:
|
||||
# network is not active, avoid doing anything
|
||||
return self.org_forward(x, *args, **kwargs)
|
||||
|
||||
if self.__class__.__name__ == "DoRAModule":
|
||||
# return dora forward
|
||||
return self.dora_forward(x, *args, **kwargs)
|
||||
# if self.__class__.__name__ == "DoRAModule":
|
||||
# # return dora forward
|
||||
# return self.dora_forward(x, *args, **kwargs)
|
||||
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
lora_output = self._call_forward(x)
|
||||
@@ -263,7 +269,27 @@ class ToolkitModuleMixin:
|
||||
# todo check if this is correct, do we just concat when doing cfg?
|
||||
multiplier = multiplier.repeat_interleave(num_interleaves)
|
||||
|
||||
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
|
||||
scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
|
||||
|
||||
if self.__class__.__name__ == "DoRAModule":
|
||||
# ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
|
||||
# x = dropout(x)
|
||||
# todo this wont match the dropout applied to the lora
|
||||
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
|
||||
lx = self.dropout(x)
|
||||
# normal dropout
|
||||
elif self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(x, p=self.dropout)
|
||||
else:
|
||||
lx = x
|
||||
lora_weight = self.lora_up.weight @ self.lora_down.weight
|
||||
# scale it here
|
||||
# todo handle our batch split scalers for slider training. For now take the mean of them
|
||||
scale = multiplier.mean()
|
||||
scaled_lora_weight = lora_weight * scale
|
||||
scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight)
|
||||
|
||||
x = org_forwarded + scaled_lora_output
|
||||
return x
|
||||
|
||||
def enable_gradient_checkpointing(self: Module):
|
||||
@@ -413,12 +439,12 @@ class ToolkitNetworkMixin:
|
||||
new_keymap = {}
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
ldm_key = ldm_key.replace('.alpha', '.magnitude')
|
||||
ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
|
||||
ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
|
||||
# ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
|
||||
# ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
|
||||
|
||||
diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
|
||||
diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
|
||||
diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
|
||||
# diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
|
||||
# diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
|
||||
|
||||
new_keymap[ldm_key] = diffusers_key
|
||||
|
||||
@@ -513,12 +539,8 @@ class ToolkitNetworkMixin:
|
||||
multiplier = self._multiplier
|
||||
# get first module
|
||||
first_module = self.get_all_modules()[0]
|
||||
if self.network_type.lower() == 'dora':
|
||||
device = first_module.lora_down.device
|
||||
dtype = first_module.lora_down.dtype
|
||||
else:
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
with torch.no_grad():
|
||||
tensor_multiplier = None
|
||||
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
||||
|
||||
Reference in New Issue
Block a user