diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 4dd94b0f..aa221480 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -209,7 +209,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): ignore_if_contains = [] self.ignore_if_contains = ignore_if_contains self.transformer_only = transformer_only - self.base_model_ref = weakref.ref(base_model) + self.base_model_ref = None + if base_model is not None: + self.base_model_ref = weakref.ref(base_model) self.only_if_contains: Union[List, None] = only_if_contains diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index b52af32b..826d7f09 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -547,7 +547,8 @@ class ToolkitNetworkMixin: save_dict = new_save_dict - save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict) + if self.base_model_ref is not None: + save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict) return save_dict def save_weights( @@ -586,7 +587,8 @@ class ToolkitNetworkMixin: # probably a state dict weights_sd = file - weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd) + if self.base_model_ref is not None: + weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd) load_sd = OrderedDict() for key, value in weights_sd.items():