mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Tons of bug fixes and improvements to special training. Fixed slider training.
This commit is contained in:
@@ -206,6 +206,7 @@ def load_t2i_model(
|
||||
|
||||
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
|
||||
|
||||
|
||||
def save_ip_adapter_from_diffusers(
|
||||
combined_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
@@ -241,3 +242,58 @@ def load_ip_adapter_model(
|
||||
return combined_state_dict
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
|
||||
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
|
||||
lora_keymap = OrderedDict()
|
||||
|
||||
# see if we have dual text encoders " a key that starts with conditioner.embedders.1
|
||||
has_dual_text_encoders = False
|
||||
for key in model_keymap:
|
||||
if key.startswith('conditioner.embedders.1'):
|
||||
has_dual_text_encoders = True
|
||||
break
|
||||
|
||||
# map through the keys and values
|
||||
for key, value in model_keymap.items():
|
||||
# ignore bias weights
|
||||
if key.endswith('bias'):
|
||||
continue
|
||||
if key.endswith('.weight'):
|
||||
# remove the .weight
|
||||
key = key[:-7]
|
||||
if value.endswith(".weight"):
|
||||
# remove the .weight
|
||||
value = value[:-7]
|
||||
|
||||
# unet for all
|
||||
key = key.replace('model.diffusion_model', 'lora_unet')
|
||||
if value.startswith('unet'):
|
||||
value = f"lora_{value}"
|
||||
|
||||
# text encoder
|
||||
if has_dual_text_encoders:
|
||||
key = key.replace('conditioner.embedders.0', 'lora_te1')
|
||||
key = key.replace('conditioner.embedders.1', 'lora_te2')
|
||||
if value.startswith('te0') or value.startswith('te1'):
|
||||
value = f"lora_{value}"
|
||||
value.replace('lora_te1', 'lora_te2')
|
||||
value.replace('lora_te0', 'lora_te1')
|
||||
|
||||
key = key.replace('cond_stage_model.transformer', 'lora_te')
|
||||
|
||||
if value.startswith('te_'):
|
||||
value = f"lora_{value}"
|
||||
|
||||
# replace periods with underscores
|
||||
key = key.replace('.', '_')
|
||||
value = value.replace('.', '_')
|
||||
|
||||
# add all the weights
|
||||
lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight"
|
||||
lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias"
|
||||
lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight"
|
||||
lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias"
|
||||
lora_keymap[f"{key}.alpha"] = f"{value}.alpha"
|
||||
|
||||
return lora_keymap
|
||||
|
||||
Reference in New Issue
Block a user