mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Initial training script for photomaker training. Needs a little more work.
This commit is contained in:
@@ -246,6 +246,25 @@ def load_ip_adapter_model(
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
def load_custom_adapter_model(
|
||||
path_to_file,
|
||||
device: Union[str] = 'cpu',
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
# check if it is safetensors or checkpoint
|
||||
if path_to_file.endswith('.safetensors'):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
combined_state_dict = OrderedDict()
|
||||
for combo_key, value in raw_state_dict.items():
|
||||
key_split = combo_key.split('.')
|
||||
module_name = key_split.pop(0)
|
||||
if module_name not in combined_state_dict:
|
||||
combined_state_dict[module_name] = OrderedDict()
|
||||
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
|
||||
return combined_state_dict
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
|
||||
def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
|
||||
lora_keymap = OrderedDict()
|
||||
|
||||
Reference in New Issue
Block a user