Added working ilora trainer

Jaret Burkett
2024-06-12 09:33:45 -06:00
parent 3f3636b788
commit cb5d28cba9
6 changed files with 261 additions and 196 deletions


@@ -249,92 +249,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
skipped = []
attached_modules = []
for name, module in root_module.named_modules():
    if is_unet:
        module_name = module.__class__.__name__
        if module not in attached_modules:
            # if module.__class__.__name__ in target_replace_modules:
            #     for child_name, child_module in module.named_modules():
            is_linear = module_name == 'LoRACompatibleLinear'
            is_conv2d = module_name == 'LoRACompatibleConv'
            # check if "attentions" is in the module name
            is_attention = "attentions" in name
            if not is_attention and attn_only:
                continue
            if is_linear and self.lora_dim is None:
                continue
            if is_conv2d and self.conv_lora_dim is None:
                continue
            is_conv2d_1x1 = is_conv2d and module.kernel_size == (1, 1)
            if is_conv2d_1x1:
                pass
            skip = False
            if any([word in name for word in self.ignore_if_contains]):
                skip = True
            # skip modules below the parameter threshold
            if count_parameters(module) < parameter_threshold:
                skip = True
            if (is_linear or is_conv2d) and not skip:
                lora_name = prefix + "." + name
                lora_name = lora_name.replace(".", "_")
                dim = None
                alpha = None
                if modules_dim is not None:
                    # per-module dims were specified explicitly
                    if lora_name in modules_dim:
                        dim = modules_dim[lora_name]
                        alpha = modules_alpha[lora_name]
                elif is_unet and block_dims is not None:
                    # block_dims specified for the U-Net
                    block_idx = get_block_index(lora_name)
                    if is_linear or is_conv2d_1x1:
                        dim = block_dims[block_idx]
                        alpha = block_alphas[block_idx]
                    elif conv_block_dims is not None:
                        dim = conv_block_dims[block_idx]
                        alpha = conv_block_alphas[block_idx]
                else:
                    # normal case: target all modules
                    if is_linear or is_conv2d_1x1:
                        dim = self.lora_dim
                        alpha = self.alpha
                    elif self.conv_lora_dim is not None:
                        dim = self.conv_lora_dim
                        alpha = self.conv_alpha
                    else:
                        dim = None
                        alpha = None
                if dim is None or dim == 0:
                    # record the modules that were skipped
                    if is_linear or is_conv2d_1x1 or (
                            self.conv_lora_dim is not None or conv_block_dims is not None):
                        skipped.append(lora_name)
                    continue
                lora = module_class(
                    lora_name,
                    module,
                    self.multiplier,
                    dim,
                    alpha,
                    dropout=dropout,
                    rank_dropout=rank_dropout,
                    module_dropout=module_dropout,
                    network=self,
                    parent=module,
                    use_bias=use_bias,
                )
                loras.append(lora)
                attached_modules.append(module)
    elif module.__class__.__name__ in target_replace_modules:
        if module.__class__.__name__ in target_replace_modules:
            for child_name, child_module in module.named_modules():
                is_linear = child_module.__class__.__name__ in LINEAR_MODULES
                is_conv2d = child_module.__class__.__name__ in CONV_MODULES