mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added early DoRA support, but will change shortly. Don't use right now.
This commit is contained in:
@@ -19,9 +19,10 @@ if TYPE_CHECKING:
|
||||
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
|
||||
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.models.DoRA import DoRAModule
|
||||
|
||||
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
|
||||
Module = Union['LoConSpecialModule', 'LoRAModule']
|
||||
Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']
|
||||
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
@@ -247,6 +248,10 @@ class ToolkitModuleMixin:
|
||||
# network is not active, avoid doing anything
|
||||
return self.org_forward(x, *args, **kwargs)
|
||||
|
||||
if self.__class__.__name__ == "DoRAModule":
|
||||
# return dora forward
|
||||
return self.dora_forward(x, *args, **kwargs)
|
||||
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
lora_output = self._call_forward(x)
|
||||
multiplier = self.network_ref().torch_multiplier
|
||||
@@ -276,6 +281,8 @@ class ToolkitModuleMixin:
|
||||
|
||||
@torch.no_grad()
|
||||
def merge_in(self: Module, merge_weight=1.0):
|
||||
if not self.can_merge_in:
|
||||
return
|
||||
# get up/down weight
|
||||
up_weight = self.lora_up.weight.clone().float()
|
||||
down_weight = self.lora_down.weight.clone().float()
|
||||
@@ -400,6 +407,23 @@ class ToolkitNetworkMixin:
|
||||
# get keymap from weights
|
||||
keymap = get_lora_keymap_from_model_keymap(keymap)
|
||||
|
||||
# upgrade keymaps for DoRA
|
||||
if self.network_type.lower() == 'dora':
|
||||
if keymap is not None:
|
||||
new_keymap = {}
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
ldm_key = ldm_key.replace('.alpha', '.magnitude')
|
||||
ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
|
||||
ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
|
||||
|
||||
diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
|
||||
diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
|
||||
diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
|
||||
|
||||
new_keymap[ldm_key] = diffusers_key
|
||||
|
||||
keymap = new_keymap
|
||||
|
||||
return keymap
|
||||
|
||||
def save_weights(
|
||||
@@ -489,8 +513,12 @@ class ToolkitNetworkMixin:
|
||||
multiplier = self._multiplier
|
||||
# get first module
|
||||
first_module = self.get_all_modules()[0]
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
if self.network_type.lower() == 'dora':
|
||||
device = first_module.lora_down.device
|
||||
dtype = first_module.lora_down.dtype
|
||||
else:
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
with torch.no_grad():
|
||||
tensor_multiplier = None
|
||||
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
||||
@@ -559,6 +587,8 @@ class ToolkitNetworkMixin:
|
||||
self._update_checkpointing()
|
||||
|
||||
def merge_in(self, merge_weight=1.0):
|
||||
if self.network_type.lower() == 'dora':
|
||||
return
|
||||
self.is_merged_in = True
|
||||
for module in self.get_all_modules():
|
||||
module.merge_in(merge_weight)
|
||||
|
||||
Reference in New Issue
Block a user