diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index c465f42..a639036 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -206,8 +206,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__.in_(LINEAR_MODULES) - is_conv2d = child_module.__class__.__name__.in_(CONV_MODULES) + is_linear = child_module.__class__.__name__ in LINEAR_MODULES + is_conv2d = child_module.__class__.__name__ in CONV_MODULES is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: