Initial training script for photomaker training. Needs a little more work.

This commit is contained in:
Jaret Burkett
2024-01-15 18:46:26 -07:00
parent 5276975fb0
commit eebd3c8212
8 changed files with 1183 additions and 24 deletions

View File

@@ -246,6 +246,25 @@ def load_ip_adapter_model(
else:
return torch.load(path_to_file, map_location=device)
def load_custom_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
if module_name not in combined_state_dict:
combined_state_dict[module_name] = OrderedDict()
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
lora_keymap = OrderedDict()