mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added IP adapter training. Not functioning correctly yet
This commit is contained in:
@@ -193,3 +193,42 @@ def load_t2i_model(
|
||||
# todo see if we need to convert dict
|
||||
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
|
||||
|
||||
def save_ip_adapter_from_diffusers(
|
||||
combined_state_dict: 'OrderedDict',
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
dtype=get_torch_dtype('fp16'),
|
||||
):
|
||||
# todo: test compatibility with non diffusers
|
||||
converted_state_dict = OrderedDict()
|
||||
for module_name, state_dict in combined_state_dict.items():
|
||||
for key, value in state_dict.items():
|
||||
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(converted_state_dict, output_file, metadata=meta)
|
||||
|
||||
|
||||
def load_ip_adapter_model(
|
||||
path_to_file,
|
||||
device: Union[str] = 'cpu',
|
||||
dtype: torch.dtype = torch.float32
|
||||
):
|
||||
# check if it is safetensors or checkpoint
|
||||
if path_to_file.endswith('.safetensors'):
|
||||
raw_state_dict = load_file(path_to_file, device)
|
||||
combined_state_dict = OrderedDict()
|
||||
for combo_key, value in raw_state_dict.items():
|
||||
key_split = combo_key.split('.')
|
||||
module_name = key_split.pop(0)
|
||||
if module_name not in combined_state_dict:
|
||||
combined_state_dict[module_name] = OrderedDict()
|
||||
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
|
||||
return combined_state_dict
|
||||
else:
|
||||
return torch.load(path_to_file, map_location=device)
|
||||
|
||||
Reference in New Issue
Block a user