Added early DoRA support, but will change shortly. Dont use right now.

This commit is contained in:
Jaret Burkett
2024-02-23 05:55:41 -07:00
parent 9ffa8c3711
commit 1bd94f0f01
5 changed files with 140 additions and 5 deletions

View File

@@ -19,9 +19,10 @@ if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.models.DoRA import DoRAModule
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']
LINEAR_MODULES = [
'Linear',
@@ -247,6 +248,10 @@ class ToolkitModuleMixin:
# network is not active, avoid doing anything
return self.org_forward(x, *args, **kwargs)
if self.__class__.__name__ == "DoRAModule":
# return dora forward
return self.dora_forward(x, *args, **kwargs)
org_forwarded = self.org_forward(x, *args, **kwargs)
lora_output = self._call_forward(x)
multiplier = self.network_ref().torch_multiplier
@@ -276,6 +281,8 @@ class ToolkitModuleMixin:
@torch.no_grad()
def merge_in(self: Module, merge_weight=1.0):
if not self.can_merge_in:
return
# get up/down weight
up_weight = self.lora_up.weight.clone().float()
down_weight = self.lora_down.weight.clone().float()
@@ -400,6 +407,23 @@ class ToolkitNetworkMixin:
# get keymap from weights
keymap = get_lora_keymap_from_model_keymap(keymap)
# upgrade keymaps for DoRA
if self.network_type.lower() == 'dora':
if keymap is not None:
new_keymap = {}
for ldm_key, diffusers_key in keymap.items():
ldm_key = ldm_key.replace('.alpha', '.magnitude')
ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
new_keymap[ldm_key] = diffusers_key
keymap = new_keymap
return keymap
def save_weights(
@@ -489,8 +513,12 @@ class ToolkitNetworkMixin:
multiplier = self._multiplier
# get first module
first_module = self.get_all_modules()[0]
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
if self.network_type.lower() == 'dora':
device = first_module.lora_down.device
dtype = first_module.lora_down.dtype
else:
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):
@@ -559,6 +587,8 @@ class ToolkitNetworkMixin:
self._update_checkpointing()
def merge_in(self, merge_weight=1.0):
if self.network_type.lower() == 'dora':
return
self.is_merged_in = True
for module in self.get_all_modules():
module.merge_in(merge_weight)