mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bugfixes
This commit is contained in:
@@ -344,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
|
||||
if (is_linear or is_conv2d) and not skip:
|
||||
|
||||
if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
|
||||
continue
|
||||
if self.only_if_contains is not None:
|
||||
if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
|
||||
@@ -101,6 +101,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
self.cp = False
|
||||
self.use_w1 = False
|
||||
self.use_w2 = False
|
||||
self.can_merge_in = True
|
||||
|
||||
self.shape = org_module.weight.shape
|
||||
if org_module.__class__.__name__ == 'Conv2d':
|
||||
|
||||
@@ -634,7 +634,10 @@ class ToolkitNetworkMixin:
|
||||
# without having to set it in every single module every time it changes
|
||||
multiplier = self._multiplier
|
||||
# get first module
|
||||
first_module = self.get_all_modules()[0]
|
||||
try:
|
||||
first_module = self.get_all_modules()[0]
|
||||
except IndexError:
|
||||
raise ValueError("There are not any lora modules in this network. Check your config and try again")
|
||||
|
||||
if hasattr(first_module, 'lora_down'):
|
||||
device = first_module.lora_down.weight.device
|
||||
|
||||
Reference in New Issue
Block a user