mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added working ilora trainer
This commit is contained in:
@@ -249,92 +249,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
skipped = []
|
||||
attached_modules = []
|
||||
for name, module in root_module.named_modules():
|
||||
if is_unet:
|
||||
module_name = module.__class__.__name__
|
||||
if module not in attached_modules:
|
||||
# if module.__class__.__name__ in target_replace_modules:
|
||||
# for child_name, child_module in module.named_modules():
|
||||
is_linear = module_name == 'LoRACompatibleLinear'
|
||||
is_conv2d = module_name == 'LoRACompatibleConv'
|
||||
# check if attn in name
|
||||
is_attention = "attentions" in name
|
||||
if not is_attention and attn_only:
|
||||
continue
|
||||
|
||||
if is_linear and self.lora_dim is None:
|
||||
continue
|
||||
if is_conv2d and self.conv_lora_dim is None:
|
||||
continue
|
||||
|
||||
is_conv2d_1x1 = is_conv2d and module.kernel_size == (1, 1)
|
||||
|
||||
if is_conv2d_1x1:
|
||||
pass
|
||||
|
||||
skip = False
|
||||
if any([word in name for word in self.ignore_if_contains]):
|
||||
skip = True
|
||||
|
||||
# see if it is over threshold
|
||||
if count_parameters(module) < parameter_threshold:
|
||||
skip = True
|
||||
|
||||
if (is_linear or is_conv2d) and not skip:
|
||||
lora_name = prefix + "." + name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
|
||||
if modules_dim is not None:
|
||||
# モジュール指定あり
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
elif is_unet and block_dims is not None:
|
||||
# U-Netでblock_dims指定あり
|
||||
block_idx = get_block_index(lora_name)
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = block_dims[block_idx]
|
||||
alpha = block_alphas[block_idx]
|
||||
elif conv_block_dims is not None:
|
||||
dim = conv_block_dims[block_idx]
|
||||
alpha = conv_block_alphas[block_idx]
|
||||
else:
|
||||
# 通常、すべて対象とする
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = self.lora_dim
|
||||
alpha = self.alpha
|
||||
elif self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha = self.conv_alpha
|
||||
else:
|
||||
dim = None
|
||||
alpha = None
|
||||
|
||||
if dim is None or dim == 0:
|
||||
# skipした情報を出力
|
||||
if is_linear or is_conv2d_1x1 or (
|
||||
self.conv_lora_dim is not None or conv_block_dims is not None):
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
network=self,
|
||||
parent=module,
|
||||
use_bias=use_bias,
|
||||
)
|
||||
loras.append(lora)
|
||||
attached_modules.append(module)
|
||||
elif module.__class__.__name__ in target_replace_modules:
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ in LINEAR_MODULES
|
||||
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
|
||||
|
||||
Reference in New Issue
Block a user