Many bug fixes. Ip adapter bug fixes. Added noise to unconditional, it works better. added an ilora adapter for 1 shotting LoRAs

This commit is contained in:
Jaret Burkett
2024-01-28 08:20:03 -07:00
parent f17ad8d794
commit 92b9c71d44
10 changed files with 352 additions and 56 deletions

View File

@@ -215,12 +215,17 @@ def save_ip_adapter_from_diffusers(
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
direct_save: bool = False
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
if direct_save:
converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype)
else:
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
@@ -230,12 +235,15 @@ def save_ip_adapter_from_diffusers(
def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
dtype: torch.dtype = torch.float32,
direct_load: bool = False
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
if direct_load:
return raw_state_dict
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)