Added IP adapter training. Not functioning correctly yet

This commit is contained in:
Jaret Burkett
2023-09-24 02:39:43 -06:00
parent 19255cdc7c
commit 830e87cb87
9 changed files with 336 additions and 53 deletions

View File

@@ -193,3 +193,42 @@ def load_t2i_model(
# todo see if we need to convert dict
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
return converted_state_dict
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
def save_ip_adapter_from_diffusers(
combined_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
if module_name not in combined_state_dict:
combined_state_dict[module_name] = OrderedDict()
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)