Added LoKr support

This commit is contained in:
Jaret Burkett
2025-03-02 06:57:50 -07:00
parent 60539c0b0f
commit b16819f8e7
4 changed files with 215 additions and 113 deletions

View File

@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
# if self.__class__.__name__ == "DoRAModule":
# # return dora forward
# return self.dora_forward(x, *args, **kwargs)
if self.__class__.__name__ == "LokrModule":
return self._call_forward(x)
org_forwarded = self.org_forward(x, *args, **kwargs)
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
new_save_dict[new_key] = value
save_dict = new_save_dict
if self.network_type.lower() == "lokr":
new_save_dict = {}
for key, value in save_dict.items():
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
new_key = key
new_key = new_key.replace('lora_transformer_', 'lycoris_')
new_save_dict[new_key] = value
save_dict = new_save_dict
if metadata is None:
metadata = OrderedDict()
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
if self.network_type.lower() == "lokr":
# on load, map the saved LyCORIS-style key back to the internal name:
# lycoris_transformer_blocks_7_attn_to_v.lokr_w1 to lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
load_key = load_key.replace('lycoris_', 'lora_transformer_')
load_sd[load_key] = value
@@ -617,8 +635,18 @@ class ToolkitNetworkMixin:
multiplier = self._multiplier
# get first module
first_module = self.get_all_modules()[0]
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
if hasattr(first_module, 'lora_down'):
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
elif hasattr(first_module, 'lokr_w1'):
device = first_module.lokr_w1.device
dtype = first_module.lokr_w1.dtype
elif hasattr(first_module, 'lokr_w1_a'):
device = first_module.lokr_w1_a.device
dtype = first_module.lokr_w1_a.dtype
else:
raise ValueError("Unknown module type")
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):