mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed bug that prevented using schnell training adapter
This commit is contained in:
@@ -209,7 +209,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
ignore_if_contains = []
|
||||
self.ignore_if_contains = ignore_if_contains
|
||||
self.transformer_only = transformer_only
|
||||
self.base_model_ref = weakref.ref(base_model)
|
||||
self.base_model_ref = None
|
||||
if base_model is not None:
|
||||
self.base_model_ref = weakref.ref(base_model)
|
||||
|
||||
self.only_if_contains: Union[List, None] = only_if_contains
|
||||
|
||||
|
||||
@@ -547,7 +547,8 @@ class ToolkitNetworkMixin:
|
||||
|
||||
save_dict = new_save_dict
|
||||
|
||||
save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)
|
||||
if self.base_model_ref is not None:
|
||||
save_dict = self.base_model_ref().convert_lora_weights_before_save(save_dict)
|
||||
return save_dict
|
||||
|
||||
def save_weights(
|
||||
@@ -586,7 +587,8 @@ class ToolkitNetworkMixin:
|
||||
# probably a state dict
|
||||
weights_sd = file
|
||||
|
||||
weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd)
|
||||
if self.base_model_ref is not None:
|
||||
weights_sd = self.base_model_ref().convert_lora_weights_before_load(weights_sd)
|
||||
|
||||
load_sd = OrderedDict()
|
||||
for key, value in weights_sd.items():
|
||||
|
||||
Reference in New Issue
Block a user