Tons of bug fixes and improvements to special training. Fixed slider training.

Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions


@@ -206,6 +206,7 @@ def load_t2i_model(
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
def save_ip_adapter_from_diffusers(
        combined_state_dict: 'OrderedDict',
        output_file: str,
@@ -241,3 +242,58 @@ def load_ip_adapter_model(
        return combined_state_dict
    else:
        return torch.load(path_to_file, map_location=device)
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
    lora_keymap = OrderedDict()

    # see if we have dual text encoders (a key that starts with conditioner.embedders.1)
    has_dual_text_encoders = False
    for key in model_keymap:
        if key.startswith('conditioner.embedders.1'):
            has_dual_text_encoders = True
            break

    # map through the keys and values
    for key, value in model_keymap.items():
        # ignore bias weights
        if key.endswith('bias'):
            continue
        if key.endswith('.weight'):
            # remove the .weight
            key = key[:-7]
        if value.endswith('.weight'):
            # remove the .weight
            value = value[:-7]

        # unet for all
        key = key.replace('model.diffusion_model', 'lora_unet')
        if value.startswith('unet'):
            value = f"lora_{value}"

        # text encoder
        if has_dual_text_encoders:
            key = key.replace('conditioner.embedders.0', 'lora_te1')
            key = key.replace('conditioner.embedders.1', 'lora_te2')
            if value.startswith('te0') or value.startswith('te1'):
                value = f"lora_{value}"
                # shift the 0-indexed te0/te1 values to the 1-indexed lora_te1/lora_te2
                # convention; te1 -> te2 must run first so te0 -> te1 is not shifted twice
                value = value.replace('lora_te1', 'lora_te2')
                value = value.replace('lora_te0', 'lora_te1')
        key = key.replace('cond_stage_model.transformer', 'lora_te')
        if value.startswith('te_'):
            value = f"lora_{value}"

        # replace periods with underscores
        key = key.replace('.', '_')
        value = value.replace('.', '_')

        # add all the weights
        lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight"
        lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias"
        lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight"
        lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias"
        lora_keymap[f"{key}.alpha"] = f"{value}.alpha"

    return lora_keymap
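
A minimal sketch of how the new helper could be exercised. The three model_keymap entries below are hypothetical stand-ins for the LDM-to-diffusers key pairs the keymap files provide (only the prefixes matter to the mapping), and get_lora_keymap_from_model_keymap is assumed to be in scope as defined above.

from collections import OrderedDict

# hypothetical keymap entries: LDM-style keys mapped to diffusers-style values
model_keymap = OrderedDict([
    ('model.diffusion_model.input_blocks.1.0.in_layers.2.weight',
     'unet.down_blocks.0.resnets.0.conv1.weight'),
    ('conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight',
     'te0.text_model.encoder.layers.0.self_attn.q_proj.weight'),
    ('conditioner.embedders.1.model.transformer.resblocks.0.attn.q_proj.weight',
     'te1.text_model.encoder.layers.0.self_attn.q_proj.weight'),
])

lora_keymap = get_lora_keymap_from_model_keymap(model_keymap)
for lora_key, lora_value in lora_keymap.items():
    print(f'{lora_key} -> {lora_value}')
# e.g. lora_unet_input_blocks_1_0_in_layers_2.lora_down.weight
#      -> lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight

Because one key starts with conditioner.embedders.1, the dual-text-encoder branch is taken: the te0/te1 values come out as lora_te1/lora_te2, which is why the te1 -> te2 rename has to run before te0 -> te1.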