mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-01 22:09:46 +00:00
Before this commit, credits were already in the entry and licenses were already in the root. This commit makes the information clearer.
124 lines
4.6 KiB
Python
124 lines
4.6 KiB
Python
# Taken from https://github.com/comfyanonymous/ComfyUI
|
|
# This file is only for reference, and not used in the backend or runtime.
|
|
|
|
|
|
import ldm_patched.modules.utils
|
|
import ldm_patched.utils.path_utils
|
|
import torch
|
|
|
|
# Activation names accepted in a hypernetwork file's "activation_func" entry,
# mapped to the torch.nn module class that is instantiated for them.
_HYPERNETWORK_ACTIVATIONS = {
    "linear": torch.nn.Identity,
    "relu": torch.nn.ReLU,
    "leakyrelu": torch.nn.LeakyReLU,
    "elu": torch.nn.ELU,
    "swish": torch.nn.Hardswish,
    "tanh": torch.nn.Tanh,
    "sigmoid": torch.nn.Sigmoid,
    "softsign": torch.nn.Softsign,
    "mish": torch.nn.Mish,
}


def _build_hypernetwork_branch(attn_weights, activation, is_layer_norm, use_dropout,
                               activate_output, last_layer_dropout):
    """Rebuild one hypernetwork branch as a ``torch.nn.Sequential``.

    Args:
        attn_weights: Mapping of ``"<name>.weight"`` / ``"<name>.bias"`` keys to
            tensors, iterated in module order. Each ``.weight`` entry is a Linear
            layer; when ``is_layer_norm`` is set, each Linear is followed by one
            LayerNorm entry.
        activation: Zero-argument module factory appended after Linear layers,
            or ``None`` for no activation.
        is_layer_norm: Whether every Linear is followed by a LayerNorm.
        use_dropout: Whether Dropout(p=0.3) layers are inserted.
        activate_output: Whether the activation is also applied after the
            final Linear layer.
        last_layer_dropout: Whether dropout is also inserted after the
            penultimate Linear layer (never after the final one).

    Returns:
        The assembled ``torch.nn.Sequential``.
    """
    names = [key[:-len(".weight")] for key in attn_weights.keys() if key.endswith(".weight")]

    # With layer norm each logical layer owns two ".weight" entries (Linear then
    # LayerNorm), so positional checks must count Linear layers, not entries.
    step = 2 if is_layer_norm else 1
    layer_starts = range(0, len(names), step)
    num_layers = len(layer_starts)

    layers = []
    for layer_index, pos in enumerate(layer_starts):
        last_layer = layer_index == num_layers - 1
        penultimate_layer = layer_index == num_layers - 2

        lin_name = names[pos]
        lin_weight = attn_weights['{}.weight'.format(lin_name)]
        lin_bias = attn_weights['{}.bias'.format(lin_name)]
        linear = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
        linear.load_state_dict({"weight": lin_weight, "bias": lin_bias})
        layers.append(linear)

        if activation is not None and (not last_layer or activate_output):
            layers.append(activation())

        if is_layer_norm:
            ln_name = names[pos + 1]
            ln_weight = attn_weights['{}.weight'.format(ln_name)]
            ln_bias = attn_weights['{}.bias'.format(ln_name)]
            ln = torch.nn.LayerNorm(ln_weight.shape[0])
            ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
            layers.append(ln)

        if use_dropout and not last_layer and (not penultimate_layer or last_layer_dropout):
            layers.append(torch.nn.Dropout(p=0.3))

    return torch.nn.Sequential(*layers)


class _HypernetworkPatch:
    """Attention patch that adds hypernetwork residuals to ``k`` and ``v``.

    ``hypernet`` maps a feature size (the last dimension of ``k``) to a
    two-element ``ModuleList``. Element 0 is applied to ``k`` and element 1 to
    ``v``. Inputs whose size has no entry pass through unchanged.
    """

    def __init__(self, hypernet, strength):
        self.hypernet = hypernet
        self.strength = strength

    def __call__(self, q, k, v, extra_options):
        dim = k.shape[-1]
        if dim in self.hypernet:
            hn = self.hypernet[dim]
            k = k + hn[0](k) * self.strength
            v = v + hn[1](v) * self.strength

        return q, k, v

    def to(self, device):
        """Move every hypernetwork module to ``device`` and return ``self``."""
        for d in self.hypernet.keys():
            self.hypernet[d] = self.hypernet[d].to(device)
        return self


def load_hypernetwork_patch(path, strength):
    """Load a hypernetwork file and wrap it as an attention patch.

    Args:
        path: Hypernetwork checkpoint path, read with
            ``ldm_patched.modules.utils.load_torch_file(..., safe_load=True)``.
        strength: Multiplier for the hypernetwork residual added to k and v.

    Returns:
        A ``_HypernetworkPatch``, or ``None`` (after printing a message) when
        the file's ``activation_func`` is not in ``_HYPERNETWORK_ACTIVATIONS``.
    """
    sd = ldm_patched.modules.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    if activation_func not in _HYPERNETWORK_ACTIVATIONS:
        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
        return None

    # "linear" means no activation module at all (not even Identity).
    activation = None if activation_func == "linear" else _HYPERNETWORK_ACTIVATIONS[activation_func]

    out = {}
    for key, value in sd.items():
        try:
            dim = int(key)
        except (TypeError, ValueError):
            # Non-numeric keys are metadata such as "activation_func".
            continue

        branches = [
            _build_hypernetwork_branch(value[index], activation, is_layer_norm, use_dropout,
                                       activate_output, last_layer_dropout)
            for index in (0, 1)
        ]
        # Modules are created in training mode; without eval() the Dropout
        # layers would randomly zero activations during inference.
        out[dim] = torch.nn.ModuleList(branches).eval()

    return _HypernetworkPatch(out, strength)
|
|
|
|
class HypernetworkLoader:
    """Node that returns a clone of a model with a hypernetwork attention patch.

    The patch is installed on both the attn1 and attn2 patch slots of the
    cloned model. If the hypernetwork cannot be loaded (the loader returns
    ``None``), the unpatched clone is returned.
    """

    CATEGORY = "loaders"
    FUNCTION = "load_hypernetwork"
    RETURN_TYPES = ("MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: model, hypernetwork file name, strength."""
        available = ldm_patched.utils.path_utils.get_filename_list("hypernetworks")
        strength_spec = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        required = {
            "model": ("MODEL",),
            "hypernetwork_name": (available, ),
            "strength": ("FLOAT", strength_spec),
        }
        return {"required": required}

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        """Return a 1-tuple holding a patched clone of ``model``."""
        full_path = ldm_patched.utils.path_utils.get_full_path("hypernetworks", hypernetwork_name)
        patched = model.clone()
        patch = load_hypernetwork_patch(full_path, strength)
        if patch is None:
            return (patched,)
        for install in (patched.set_model_attn1_patch, patched.set_model_attn2_patch):
            install(patch)
        return (patched,)
|
|
|
|
# Node registry: maps the node's registered name to its implementing class.
NODE_CLASS_MAPPINGS = dict(
    HypernetworkLoader=HypernetworkLoader,
)
|