Merge changes from dev

This commit is contained in:
Sj-Si
2024-01-11 16:37:35 -05:00
117 changed files with 3368 additions and 4917 deletions

View File

@@ -3,6 +3,9 @@ import os
from collections import namedtuple
import enum
import torch.nn as nn
import torch.nn.functional as F
from modules import sd_models, cache, errors, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.ops = None
self.extra_kwargs = {}
if isinstance(self.sd_module, nn.Conv2d):
self.ops = F.conv2d
self.extra_kwargs = {
'stride': self.sd_module.stride,
'padding': self.sd_module.padding
}
elif isinstance(self.sd_module, nn.Linear):
self.ops = F.linear
elif isinstance(self.sd_module, nn.LayerNorm):
self.ops = F.layer_norm
self.extra_kwargs = {
'normalized_shape': self.sd_module.normalized_shape,
'eps': self.sd_module.eps
}
elif isinstance(self.sd_module, nn.GroupNorm):
self.ops = F.group_norm
self.extra_kwargs = {
'num_groups': self.sd_module.num_groups,
'eps': self.sd_module.eps
}
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule:
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
raise NotImplementedError()
def forward(self, x, y):
raise NotImplementedError()
"""A general forward implementation for all modules"""
if self.ops is None:
raise NotImplementedError()
else:
updown, ex_bias = self.calc_updown(self.sd_module.weight)
return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)

View File

@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None

View File

@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)

View File

@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

View File

@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:

View File

@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
def calc_updown(self, orig_weight):
if self.w1 is not None:
w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
w1 = self.w1.to(orig_weight.device)
else:
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w1a = self.w1a.to(orig_weight.device)
w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
t2 = self.t2.to(orig_weight.device)
w2a = self.w2a.to(orig_weight.device)
w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]

View File

@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, orig_weight):
up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
up = self.up_model.weight.to(orig_weight.device)
down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:

View File

@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
updown = self.w_norm.to(orig_weight.device)
if self.b_norm is not None:
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
ex_bias = self.b_norm.to(orig_weight.device)
else:
ex_bias = None

View File

@@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
def calc_updown(self, orig_weight):
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
oft_blocks = self.oft_blocks.to(orig_weight.device)
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
if self.is_kohya:
@@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule):
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
R = oft_blocks.to(orig_weight.device)
# This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule):
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)

View File

@@ -1,3 +1,4 @@
import gradio as gr
import logging
import os
import re
@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
emb_db.skipped_embeddings[name] = embedding
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
sd_hijack.model_hijack.comments.append(lora_not_found_message)
if shared.opts.lora_not_found_warning_console:
print(f'\n{lora_not_found_message}\n')
if shared.opts.lora_not_found_gradio_warning:
gr.Warning(lora_not_found_message)
purge_networks_from_memory()
@@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)
if getattr(self, 'fp16_weight', None) is None:
weight = self.weight
bias = self.bias
else:
weight = self.fp16_weight.clone().to(self.weight.device)
bias = getattr(self, 'fp16_bias', None)
if bias is not None:
bias = bias.clone().to(self.bias.device)
updown, ex_bias = module.calc_updown(weight)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown
self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias)
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
self.bias += ex_bias
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_current_names = wanted_names
def network_forward(module, input, original_forward):
def network_forward(org_module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_networks) == 0:
return original_forward(module, input)
return original_forward(org_module, input)
input = devices.cond_cast_unet(input)
network_restore_weights_from_backup(module)
network_reset_cached_weight(module)
network_restore_weights_from_backup(org_module)
network_reset_cached_weight(org_module)
y = original_forward(module, input)
y = original_forward(org_module, input)
network_layer_name = getattr(module, 'network_layer_name', None)
network_layer_name = getattr(org_module, 'network_layer_name', None)
for lora in loaded_networks:
module = lora.modules.get(network_layer_name, None)
if module is None:

View File

@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))

View File

@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.slider_preferred_weight = None
self.edit_notes = None
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes
self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
row_random_prompt,
random_prompt,
]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
self.edit_notes,
]
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)

View File

@@ -48,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
negative_prompt = item["user_metadata"].get("negative text")
item["negative_prompt"] = quote_js("")
if negative_prompt:
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version