mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Added LoKr support
This commit is contained in:
@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
|
||||
# if self.__class__.__name__ == "DoRAModule":
|
||||
# # return dora forward
|
||||
# return self.dora_forward(x, *args, **kwargs)
|
||||
|
||||
if self.__class__.__name__ == "LokrModule":
|
||||
return self._call_forward(x)
|
||||
|
||||
org_forwarded = self.org_forward(x, *args, **kwargs)
|
||||
|
||||
@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
|
||||
new_save_dict[new_key] = value
|
||||
|
||||
save_dict = new_save_dict
|
||||
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
new_save_dict = {}
|
||||
for key, value in save_dict.items():
|
||||
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||
new_key = key
|
||||
new_key = new_key.replace('lora_transformer_', 'lycoris_')
|
||||
new_save_dict[new_key] = value
|
||||
|
||||
save_dict = new_save_dict
|
||||
|
||||
if metadata is None:
|
||||
metadata = OrderedDict()
|
||||
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
|
||||
load_key = load_key.replace('.', '$$')
|
||||
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
|
||||
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
# lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
|
||||
load_key = load_key.replace('lycoris_', 'lora_transformer_')
|
||||
|
||||
load_sd[load_key] = value
|
||||
|
||||
@@ -617,8 +635,18 @@ class ToolkitNetworkMixin:
|
||||
multiplier = self._multiplier
|
||||
# get first module
|
||||
first_module = self.get_all_modules()[0]
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
|
||||
if hasattr(first_module, 'lora_down'):
|
||||
device = first_module.lora_down.weight.device
|
||||
dtype = first_module.lora_down.weight.dtype
|
||||
elif hasattr(first_module, 'lokr_w1'):
|
||||
device = first_module.lokr_w1.device
|
||||
dtype = first_module.lokr_w1.dtype
|
||||
elif hasattr(first_module, 'lokr_w1_a'):
|
||||
device = first_module.lokr_w1_a.device
|
||||
dtype = first_module.lokr_w1_a.dtype
|
||||
else:
|
||||
raise ValueError("Unknown module type")
|
||||
with torch.no_grad():
|
||||
tensor_multiplier = None
|
||||
if isinstance(multiplier, int) or isinstance(multiplier, float):
|
||||
|
||||
Reference in New Issue
Block a user