mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Switched to new bucket system that matched sdxl trained buckets. Fixed requirements. Updated embeddings to work with sdxl. Added method to train lora with an embedding at the trigger. Still testing but works amazingly well from what I can see
This commit is contained in:
@@ -228,9 +228,15 @@ class ToolkitNetworkMixin:
|
||||
|
||||
return keymap
|
||||
|
||||
def save_weights(self: Network, file, dtype=torch.float16, metadata=None):
|
||||
def save_weights(
|
||||
self: Network,
|
||||
file, dtype=torch.float16,
|
||||
metadata=None,
|
||||
extra_state_dict: Optional[OrderedDict] = None
|
||||
):
|
||||
keymap = self.get_keymap()
|
||||
|
||||
|
||||
save_keymap = {}
|
||||
if keymap is not None:
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
@@ -249,6 +255,13 @@ class ToolkitNetworkMixin:
|
||||
save_key = save_keymap[key] if key in save_keymap else key
|
||||
save_dict[save_key] = v
|
||||
|
||||
if extra_state_dict is not None:
|
||||
# add extra items to state dict
|
||||
for key in list(extra_state_dict.keys()):
|
||||
v = extra_state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
save_dict[key] = v
|
||||
|
||||
if metadata is None:
|
||||
metadata = OrderedDict()
|
||||
metadata = add_model_hash_to_meta(state_dict, metadata)
|
||||
@@ -275,8 +288,21 @@ class ToolkitNetworkMixin:
|
||||
load_key = keymap[key] if key in keymap else key
|
||||
load_sd[load_key] = value
|
||||
|
||||
# extract extra items from state dict
|
||||
current_state_dict = self.state_dict()
|
||||
extra_dict = OrderedDict()
|
||||
to_delete = []
|
||||
for key in list(load_sd.keys()):
|
||||
if key not in current_state_dict:
|
||||
extra_dict[key] = load_sd[key]
|
||||
to_delete.append(key)
|
||||
for key in to_delete:
|
||||
del load_sd[key]
|
||||
|
||||
info = self.load_state_dict(load_sd, False)
|
||||
return info
|
||||
if len(extra_dict.keys()) == 0:
|
||||
extra_dict = None
|
||||
return extra_dict
|
||||
|
||||
@property
|
||||
def multiplier(self) -> Union[float, List[float]]:
|
||||
|
||||
Reference in New Issue
Block a user