Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Synced 2026-04-30 11:11:15 +00:00
Merge changes from dev
@@ -3,6 +3,9 @@ import os
 from collections import namedtuple
 import enum
 
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
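Note: the block added above wires each supported layer type to its torch.nn.functional counterpart and records the keyword arguments (stride/padding, normalized_shape/eps, num_groups/eps) needed to call it faithfully later. A minimal sketch of that equivalence, assuming plain PyTorch and illustrative shapes only (not webui code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # A Conv2d module and its functional counterpart agree when the functional
    # call reuses the module's own stride and padding, which is exactly what
    # extra_kwargs preserves for the generic forward further down.
    conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
    x = torch.randn(1, 3, 16, 16)
    assert torch.allclose(conv(x), F.conv2d(x, weight=conv.weight, bias=conv.bias,
                                            stride=conv.stride, padding=conv.padding))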
@@ -137,7 +163,7 @@ class NetworkModule:
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
 
         if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
 
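Note: the generic forward above leaves the wrapped layer untouched and adds the network's contribution to its output y. For the linear case this is just distributivity, F.linear(x, W + dW) == F.linear(x, W) + F.linear(x, dW), which is what makes the additive formulation equivalent to patching the weight. A small sketch under that assumption, with toy tensors:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 16)
    W = torch.randn(32, 16)          # base layer weight
    dW = torch.randn(32, 16) * 0.01  # stand-in for what calc_updown() would return

    y_patched = F.linear(x, W + dW)                 # merge-into-weight route
    y_additive = F.linear(x, W) + F.linear(x, dW)   # route taken by forward() above
    assert torch.allclose(y_patched, y_additive, atol=1e-5)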
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
         if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
         else:
             ex_bias = None
 
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
         self.w2b = weights.w["b2.weight"]
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
-        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+        updown = ((w2b @ w1b) + ((orig_weight.to(dtype=w1a.dtype) @ w2a) @ w1a))
 
         return self.finalize_updown(updown, orig_weight, output_shape)
 
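Note: with the dtype casts dropped from the .to() calls, network tensors keep their own dtype and only follow the base weight's device; the new orig_weight.to(dtype=w1a.dtype) inside the matmul is what keeps the operands compatible when the checkpoint weight is stored in a different precision. A hedged illustration with toy tensors (not webui objects); PyTorch matmul generally refuses mixed dtypes:

    import torch

    orig_weight = torch.randn(8, 8, dtype=torch.float16)  # e.g. a half-precision checkpoint weight
    w2a = torch.randn(8, 4, dtype=torch.float32)           # network weight kept in its own dtype

    try:
        _ = orig_weight @ w2a                               # mixed-dtype matmul typically raises
    except RuntimeError as err:
        print(f"without the cast: {err}")

    _ = orig_weight.to(dtype=w2a.dtype) @ w2a               # cast one operand first, as the diff does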
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
         self.t2 = weights.w.get("hada_t2")
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
 
         if self.t1 is not None:
             output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
             updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
             output_shape += t1.shape[2:]
         else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
             updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
 
         if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
             updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         else:
             updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
 
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
         self.on_input = weights.w["on_input"].item()
 
     def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)
 
         output_shape = [w.size(0), orig_weight.size(1)]
         if self.on_input:
 
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
         else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
             w1 = w1a @ w1b
 
         if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
         elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = w2a @ w2b
         else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
 
         output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
 
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
         return module
 
     def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)
 
         output_shape = [up.size(0), down.size(1)]
         if self.mid_model is not None:
             # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
             updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
             output_shape += mid.shape[2:]
         else:
 
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)
 
         if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
         else:
             ex_bias = None
 
@@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
         self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
 
     def calc_updown(self, orig_weight):
-        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        oft_blocks = self.oft_blocks.to(orig_weight.device)
         eye = torch.eye(self.block_size, device=self.oft_blocks.device)
 
         if self.is_kohya:
@@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule):
             block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
 
-        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        R = oft_blocks.to(orig_weight.device)
 
         # This errors out for MultiheadAttention, might need to be handled up-stream
         merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule):
         )
         merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
 
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
 
@@ -1,3 +1,4 @@
+import gradio as gr
 import logging
 import os
 import re
@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
                 emb_db.skipped_embeddings[name] = embedding
 
     if failed_to_load_networks:
-        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+        sd_hijack.model_hijack.comments.append(lora_not_found_message)
+        if shared.opts.lora_not_found_warning_console:
+            print(f'\n{lora_not_found_message}\n')
+        if shared.opts.lora_not_found_gradio_warning:
+            gr.Warning(lora_not_found_message)
 
     purge_networks_from_memory()
 
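Note: the new warning path above builds one message and routes it to the model comments, optionally to the console, and optionally to a gr.Warning toast, gated by the two options added further down. A rough standalone sketch of the same pattern (option flags and function name here are made up for illustration):

    import gradio as gr

    # Hypothetical stand-ins for the shared.opts flags; names are illustrative only.
    warn_in_console = True
    warn_in_ui = True

    def apply_networks(requested, available):
        missing = [name for name in requested if name not in available]
        if missing:
            message = f'Lora not found: {", ".join(missing)}'
            if warn_in_console:
                print(f'\n{message}\n')
            if warn_in_ui:
                gr.Warning(message)  # shows a toast when called from inside a Gradio event handler
        return [name for name in requested if name in available]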
@@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
         if module is not None and hasattr(self, 'weight'):
             try:
                 with torch.no_grad():
-                    updown, ex_bias = module.calc_updown(self.weight)
+                    if getattr(self, 'fp16_weight', None) is None:
+                        weight = self.weight
+                        bias = self.bias
+                    else:
+                        weight = self.fp16_weight.clone().to(self.weight.device)
+                        bias = getattr(self, 'fp16_bias', None)
+                        if bias is not None:
+                            bias = bias.clone().to(self.bias.device)
+                    updown, ex_bias = module.calc_updown(weight)
 
-                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                    if len(weight.shape) == 4 and weight.shape[1] == 9:
                         # inpainting model. zero pad updown to make channel[1] 4 to 9
                         updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
 
-                    self.weight += updown
+                    self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                     if ex_bias is not None and hasattr(self, 'bias'):
                         if self.bias is None:
-                            self.bias = torch.nn.Parameter(ex_bias)
+                            self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                         else:
-                            self.bias += ex_bias
+                            self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
 
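Note: the rewritten merge above prefers a higher-precision copy of the weight when one is attached (fp16_weight), does the addition in that precision, and writes the result back with copy_(), rounding to the parameter's storage dtype only once instead of accumulating += in it. A minimal sketch of the same pattern in isolation (illustrative dtypes and tensors, not the actual webui attributes):

    import torch

    # Merge a small update into a parameter stored in float16 while doing the
    # addition in float32, then copy the result back into the low-precision storage.
    param = torch.nn.Parameter(torch.randn(64, 64, dtype=torch.float16))
    fp32_backup = param.detach().float()               # analogue of a higher-precision backup copy
    updown = torch.randn(64, 64, dtype=torch.float32) * 1e-4

    with torch.no_grad():
        merged = fp32_backup + updown                  # accumulate in higher precision
        param.copy_(merged.to(dtype=param.dtype))      # single rounding step on write-back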
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     self.network_current_names = wanted_names
 
 
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """
 
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
 
     input = devices.cond_cast_unet(input)
 
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
 
-    y = original_forward(module, input)
+    y = original_forward(org_module, input)
 
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
 
@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
     "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
     "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+    "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+    "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
 }))
 
@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
         self.slider_preferred_weight = None
         self.edit_notes = None
 
-    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
         user_metadata = self.get_user_metadata(name)
         user_metadata["description"] = desc
         user_metadata["sd version"] = sd_version
         user_metadata["activation text"] = activation_text
         user_metadata["preferred weight"] = preferred_weight
+        user_metadata["negative text"] = negative_text
         user_metadata["notes"] = notes
 
         self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
             user_metadata.get('activation text', ''),
             float(user_metadata.get('preferred weight', 0.0)),
+            user_metadata.get('negative text', ''),
             gr.update(visible=True if tags else False),
             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
         ]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
         self.taginfo = gr.HighlightedText(label="Training dataset tags")
         self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
         self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
-
+        self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
         with gr.Row() as row_random_prompt:
             with gr.Column(scale=8):
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
             self.taginfo,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             row_random_prompt,
             random_prompt,
         ]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
             self.select_sd_version,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             self.edit_notes,
         ]
 
+
         self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
 
@@ -48,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         if activation_text:
             item["prompt"] += " + " + quote_js(" " + activation_text)
 
+        negative_prompt = item["user_metadata"].get("negative text")
+        item["negative_prompt"] = quote_js("")
+        if negative_prompt:
+            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
         sd_version = item["user_metadata"].get("sd version")
         if sd_version in network.SdVersion.__members__:
             item["sd_version"] = sd_version