Switched to a new bucket system that matches the SDXL trained buckets. Fixed requirements. Updated embeddings to work with SDXL. Added a method to train a LoRA with an embedding at the trigger. Still testing, but it works amazingly well from what I can see.

Jaret Burkett
2023-09-07 13:06:18 -06:00
parent 436bf0c6a3
commit 3feb663a51
10 changed files with 208 additions and 140 deletions


@@ -228,9 +228,15 @@ class ToolkitNetworkMixin:
         return keymap
-    def save_weights(self: Network, file, dtype=torch.float16, metadata=None):
+    def save_weights(
+            self: Network,
+            file, dtype=torch.float16,
+            metadata=None,
+            extra_state_dict: Optional[OrderedDict] = None
+    ):
         keymap = self.get_keymap()
         save_keymap = {}
         if keymap is not None:
             for ldm_key, diffusers_key in keymap.items():
@@ -249,6 +255,13 @@ class ToolkitNetworkMixin:
             save_key = save_keymap[key] if key in save_keymap else key
             save_dict[save_key] = v
+        if extra_state_dict is not None:
+            # add extra items to state dict
+            for key in list(extra_state_dict.keys()):
+                v = extra_state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                save_dict[key] = v
         if metadata is None:
             metadata = OrderedDict()
         metadata = add_model_hash_to_meta(state_dict, metadata)
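
The new extra_state_dict parameter lets the caller bundle tensors that do not belong to the network itself (such as a trained embedding for the trigger) into the same file as the LoRA weights; each extra tensor is detached, moved to CPU, and cast to the save dtype before being written. A minimal usage sketch, assuming `network` is an instance of a network built on this mixin; the embedding tensor and the "emb_params" key name are illustrative assumptions, not taken from this commit:

    from collections import OrderedDict
    import torch

    # Hypothetical trained embedding vectors for the trigger (shape is illustrative).
    embedding_vectors = torch.randn(4, 2048)

    extra = OrderedDict()
    extra["emb_params"] = embedding_vectors  # assumed key name, not defined by this diff

    # The extra tensors are written alongside the network's own keys in one file.
    network.save_weights(
        "my_lora_with_embedding.safetensors",
        dtype=torch.float16,
        extra_state_dict=extra,
    )
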
@@ -275,8 +288,21 @@ class ToolkitNetworkMixin:
             load_key = keymap[key] if key in keymap else key
             load_sd[load_key] = value
+        # extract extra items from state dict
+        current_state_dict = self.state_dict()
+        extra_dict = OrderedDict()
+        to_delete = []
+        for key in list(load_sd.keys()):
+            if key not in current_state_dict:
+                extra_dict[key] = load_sd[key]
+                to_delete.append(key)
+        for key in to_delete:
+            del load_sd[key]
         info = self.load_state_dict(load_sd, False)
-        return info
+        if len(extra_dict.keys()) == 0:
+            extra_dict = None
+        return extra_dict
     @property
     def multiplier(self) -> Union[float, List[float]]:
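
On the load side, keys that are not present in the network's own state_dict() are split off before load_state_dict is called, and the method now returns that extra dictionary (or None when the file contained nothing beyond the network weights) instead of the load info. A hedged sketch of recovering a bundled embedding; the method name load_weights and the "emb_params" key are assumptions here, since the enclosing method signature is not shown in this hunk:

    # Assumes the hunk above belongs to the network's load path (e.g. load_weights).
    extra = network.load_weights("my_lora_with_embedding.safetensors")

    if extra is not None and "emb_params" in extra:
        embedding_vectors = extra["emb_params"]
        # Re-attach the recovered embedding to the tokenizer / text encoder as needed.
        print("recovered bundled embedding:", tuple(embedding_vectors.shape))
    else:
        print("file contained only network weights")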