diff --git a/backend/misc/diffusers_state_dict.py b/backend/misc/diffusers_state_dict.py new file mode 100644 index 00000000..325d28b5 --- /dev/null +++ b/backend/misc/diffusers_state_dict.py @@ -0,0 +1,134 @@ +UNET_MAP_ATTENTIONS = { + "proj_in.weight", + "proj_in.bias", + "proj_out.weight", + "proj_out.bias", + "norm.weight", + "norm.bias", +} + +TRANSFORMER_BLOCKS = { + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + "norm3.weight", + "norm3.bias", + "attn1.to_q.weight", + "attn1.to_k.weight", + "attn1.to_v.weight", + "attn1.to_out.0.weight", + "attn1.to_out.0.bias", + "attn2.to_q.weight", + "attn2.to_k.weight", + "attn2.to_v.weight", + "attn2.to_out.0.weight", + "attn2.to_out.0.bias", + "ff.net.0.proj.weight", + "ff.net.0.proj.bias", + "ff.net.2.weight", + "ff.net.2.bias", +} + +UNET_MAP_RESNET = { + "in_layers.2.weight": "conv1.weight", + "in_layers.2.bias": "conv1.bias", + "emb_layers.1.weight": "time_emb_proj.weight", + "emb_layers.1.bias": "time_emb_proj.bias", + "out_layers.3.weight": "conv2.weight", + "out_layers.3.bias": "conv2.bias", + "skip_connection.weight": "conv_shortcut.weight", + "skip_connection.bias": "conv_shortcut.bias", + "in_layers.0.weight": "norm1.weight", + "in_layers.0.bias": "norm1.bias", + "out_layers.0.weight": "norm2.weight", + "out_layers.0.bias": "norm2.bias", +} + +UNET_MAP_BASIC = { + ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), + ("label_emb.0.0.weight", "add_embedding.linear_1.weight"), + ("label_emb.0.0.bias", "add_embedding.linear_1.bias"), + ("label_emb.0.2.weight", "add_embedding.linear_2.weight"), + ("label_emb.0.2.bias", "add_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias") +} + + +def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} + num_res_blocks = unet_config["num_res_blocks"] + channel_mult = unet_config["channel_mult"] + transformer_depth = unet_config["transformer_depth"][:] + transformer_depth_output = unet_config["transformer_depth_output"][:] + num_blocks = len(channel_mult) + + transformers_mid = unet_config.get("transformer_depth_middle", None) + + diffusers_unet_map = {} + for x in range(num_blocks): + n = 1 + (num_res_blocks[x] + 1) * x + for i in range(num_res_blocks[x]): + for b in UNET_MAP_RESNET: + diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) + for t in range(num_transformers): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) + n += 1 + for k in ["weight", "bias"]: + 
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k) + + i = 0 + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b) + for t in range(transformers_mid): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b) + + for i, n in enumerate([0, 2]): + for b in UNET_MAP_RESNET: + diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) + + num_res_blocks = list(reversed(num_res_blocks)) + for x in range(num_blocks): + n = (num_res_blocks[x] + 1) * x + l = num_res_blocks[x] + 1 + for i in range(l): + c = 0 + for b in UNET_MAP_RESNET: + diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) + c += 1 + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: + c += 1 + for b in UNET_MAP_ATTENTIONS: + diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) + for t in range(num_transformers): + for b in TRANSFORMER_BLOCKS: + diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) + if i == l - 1: + for k in ["weight", "bias"]: + diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k) + n += 1 + + for k in UNET_MAP_BASIC: + diffusers_unet_map[k[1]] = k[0] + + return diffusers_unet_map diff --git a/backend/patcher/base.py b/backend/patcher/base.py new file mode 100644 index 00000000..2b12ac07 --- /dev/null +++ b/backend/patcher/base.py @@ -0,0 +1,500 @@ +# Model Patching API Template Extracted From ComfyUI +# The actual implementation for those APIs are from Forge, implemented from scratch (after forge-v1.0.1), +# and may have certain level of differences. 
+ +import torch +import copy +import inspect + +from backend import memory_management, utils + +extra_weight_calculators = {} + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength): + dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32) + lora_diff *= alpha + weight_calc = weight + lora_diff.type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength != 1.0: + weight_calc -= weight + weight += strength * weight_calc + else: + weight[:] = weight_calc + return weight + + +def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): + to = model_options["transformer_options"].copy() + + if "patches_replace" not in to: + to["patches_replace"] = {} + else: + to["patches_replace"] = to["patches_replace"].copy() + + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + else: + to["patches_replace"][name] = to["patches_replace"][name].copy() + + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch + model_options["transformer_options"] = to + return model_options + + +def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False): + model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + if disable_cfg1_optimization: + model_options["disable_cfg1_optimization"] = True + return model_options + + +def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False): + model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function] + if disable_cfg1_optimization: + model_options["disable_cfg1_optimization"] = True + return model_options + + +class ModelPatcher: + def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): + self.size = size + self.model = model + self.patches = {} + self.backup = {} + self.object_patches = {} + self.object_patches_backup = {} + self.model_options = {"transformer_options": {}} + self.model_size() + self.load_device = load_device + self.offload_device = offload_device + if current_device is None: + self.current_device = self.offload_device + else: + self.current_device = current_device + + self.weight_inplace_update = weight_inplace_update + + def model_size(self): + if self.size > 0: + return self.size + self.size = memory_management.module_size(self.model) + return self.size + + def clone(self): + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) + n.patches = {} + for k in self.patches: + n.patches[k] = self.patches[k][:] + + n.object_patches = self.object_patches.copy() + n.model_options = copy.deepcopy(self.model_options) + return n + + def is_clone(self, other): + if hasattr(other, 'model') and self.model is other.model: + return True + return False + + def memory_required(self, input_shape): + return self.model.memory_required(input_shape=input_shape) + + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): 
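# --- Illustrative sketch (editor's note, not part of the patch) ---
# weight_decompose() above implements DoRA-style merging: the LoRA diff is
# added to the weight, then every column (one per input channel) is rescaled
# so its norm equals the stored dora_scale. A minimal numeric check with
# made-up tensors, assuming this module is importable as backend.patcher.base:
import torch
from backend.patcher.base import weight_decompose

W = torch.randn(4, 8)                 # base weight, (out_features, in_features)
delta = 0.1 * torch.randn(4, 8)       # LoRA diff of the same shape
dora_scale = torch.ones(1, 8)         # target per-column magnitudes

merged = weight_decompose(dora_scale, W.clone(), delta.clone(), alpha=1.0, strength=1.0)
print(merged.norm(dim=0))             # each column norm is ~1.0, i.e. dora_scale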
+ if len(inspect.signature(sampler_cfg_function).parameters) == 3: + self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) # Old way + else: + self.model_options["sampler_cfg_function"] = sampler_cfg_function + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True + + def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): + self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) + + def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False): + self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization) + + def set_model_unet_function_wrapper(self, unet_wrapper_function): + self.model_options["model_function_wrapper"] = unet_wrapper_function + + def set_model_vae_encode_wrapper(self, wrapper_function): + self.model_options["model_vae_encode_wrapper"] = wrapper_function + + def set_model_vae_decode_wrapper(self, wrapper_function): + self.model_options["model_vae_decode_wrapper"] = wrapper_function + + def set_model_vae_regulation(self, vae_regulation): + self.model_options["model_vae_regulation"] = vae_regulation + + def set_model_denoise_mask_function(self, denoise_mask_function): + self.model_options["denoise_mask_function"] = denoise_mask_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): + self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index) + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) + + def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) + + def set_model_attn1_output_patch(self, patch): + self.set_model_patch(patch, "attn1_output_patch") + + def set_model_attn2_output_patch(self, patch): + self.set_model_patch(patch, "attn2_output_patch") + + def set_model_input_block_patch(self, patch): + self.set_model_patch(patch, "input_block_patch") + + def set_model_input_block_patch_after_skip(self, patch): + self.set_model_patch(patch, "input_block_patch_after_skip") + + def set_model_output_block_patch(self, patch): + self.set_model_patch(patch, "output_block_patch") + + def add_object_patch(self, name, obj): + self.object_patches[name] = obj + + def get_model_object(self, name): + if name in self.object_patches: + return self.object_patches[name] + else: + if name in self.object_patches_backup: + return self.object_patches_backup[name] + else: + return utils.get_attr(self.model, name) + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if 
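# --- Illustrative sketch (editor's note, not part of the patch) ---
# The set_model_*_replace helpers above store patches in
# model_options["transformer_options"]["patches_replace"][name], keyed by
# (block_name, number) or (block_name, number, transformer_index).
# Hypothetical registration; the replacement callable's exact signature is
# defined by the attention code that reads these options, not by this file,
# and the "input"/"middle" block names are an assumption based on the UNet
# section layout:
def my_attn2_replace(*args, **kwargs):
    ...  # placeholder implementation

patcher = existing_patcher.clone()  # `existing_patcher` is an assumed ModelPatcher
patcher.set_model_attn2_replace(my_attn2_replace, "input", 4)
patcher.set_model_attn2_replace(my_attn2_replace, "middle", 0, transformer_index=0)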
hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + patch_list[k] = patch_list[k].to(device) + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, "to"): + self.model_options["model_function_wrapper"] = wrap_func.to(device) + + def model_dtype(self): + if hasattr(self.model, "get_dtype"): + return self.model.get_dtype() + + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] + + if key in model_sd: + p.add(k) + current_patches = self.patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + self.patches[key] = current_patches + + return list(p) + + def get_key_patches(self, filter_prefix=None): + memory_management.unload_model_clones(self) + model_sd = self.model_state_dict() + p = {} + for k in model_sd: + if filter_prefix is not None: + if not k.startswith(filter_prefix): + continue + if k in self.patches: + p[k] = [model_sd[k]] + self.patches[k] + else: + p[k] = (model_sd[k],) + return p + + def model_state_dict(self, filter_prefix=None): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd + + def patch_model(self, device_to=None, patch_weights=True): + for k in self.object_patches: + old = utils.get_attr(self.model, k) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old + utils.set_attr_raw(self.model, k, self.object_patches[k]) + + if patch_weights: + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. 
key doesn't exist in model:", key) + continue + + weight = model_sd[key] + + inplace_update = self.weight_inplace_update + + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + utils.copy_to_param(self.model, key, out_weight) + else: + utils.set_attr(self.model, key, out_weight) + del temp_weight + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + + return self.model + + def calculate_weight(self, patches, weight, key): + for p in patches: + strength = p[0] + v = p[1] + strength_model = p[2] + offset = p[3] + function = p[4] + if function is None: + function = lambda a: a + + old_weight = None + if offset is not None: + old_weight = weight + weight = weight.narrow(offset[0], offset[1], offset[2]) + + if strength_model != 1.0: + weight *= strength_model + + if isinstance(v, list): + v = (self.calculate_weight(v[1:], v[0].clone(), key),) + + patch_type = '' + + if len(v) == 1: + patch_type = "diff" + elif len(v) == 2: + patch_type = v[0] + v = v[1] + + if patch_type == "diff": + w1 = v[0] + if strength != 0.0: + if w1.shape != weight.shape: + if w1.ndim == weight.ndim == 4: + new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)] + print(f'Merged with {key} channel changed to {new_shape}') + new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype) + new_weight = torch.zeros(size=new_shape).to(weight) + new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight + new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff + new_weight = new_weight.contiguous().clone() + weight = new_weight + else: + print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + else: + weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype) + elif patch_type == "lora": + mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32) + mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32) + dora_scale = v[4] + if v[2] is not None: + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + + if v[3] is not None: + mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) + try: + lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) + if dora_scale is not None: + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + print("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type == "lokr": + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dora_scale = v[8] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32), + memory_management.cast_to_device(w1_b, weight.device, 
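# --- Illustrative sketch (editor's note, not part of the patch) ---
# The "diff" branch above handles 4-D shape mismatches by growing the weight
# to the per-dimension max of both shapes before adding the diff, e.g. when a
# conv_in diff carries extra input channels (inpainting-style models). A
# stand-alone repetition of that logic with made-up shapes:
import torch

weight = torch.randn(320, 4, 3, 3)   # base conv_in.weight
w1 = torch.randn(320, 9, 3, 3)       # diff with extra input channels

new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]   # (320, 9, 3, 3)
new_weight = torch.zeros(new_shape)
new_weight[:, :weight.shape[1]] = weight   # original channels keep their values
new_weight[:, :w1.shape[1]] += w1          # diff added (strength factor omitted)
weight = new_weight.contiguous()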
torch.float32)) + else: + w1 = memory_management.cast_to_device(w1, weight.device, torch.float32) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32), + memory_management.cast_to_device(w2_b, weight.device, torch.float32)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + memory_management.cast_to_device(t2, weight.device, torch.float32), + memory_management.cast_to_device(w2_b, weight.device, torch.float32), + memory_management.cast_to_device(w2_a, weight.device, torch.float32)) + else: + w2 = memory_management.cast_to_device(w2, weight.device, torch.float32) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha = v[2] / dim + else: + alpha = 1.0 + + try: + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + print("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type == "loha": + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + + w2a = v[3] + w2b = v[4] + dora_scale = v[7] + if v[5] is not None: + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + memory_management.cast_to_device(t1, weight.device, torch.float32), + memory_management.cast_to_device(w1b, weight.device, torch.float32), + memory_management.cast_to_device(w1a, weight.device, torch.float32)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + memory_management.cast_to_device(t2, weight.device, torch.float32), + memory_management.cast_to_device(w2b, weight.device, torch.float32), + memory_management.cast_to_device(w2a, weight.device, torch.float32)) + else: + m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32), + memory_management.cast_to_device(w1b, weight.device, torch.float32)) + m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32), + memory_management.cast_to_device(w2b, weight.device, torch.float32)) + + try: + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + print("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type == "glora": + if v[4] is not None: + alpha = v[4] / v[0].shape[0] + else: + alpha = 1.0 + + dora_scale = v[5] + + a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) + a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) + b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) + b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) + + try: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape) + if dora_scale is not None: + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + print("ERROR {} {} {}".format(patch_type, key, e)) + elif patch_type in extra_weight_calculators: + weight = 
extra_weight_calculators[patch_type](weight, strength, v) + else: + print("patch type not recognized {} {}".format(patch_type, key)) + + if old_weight is not None: + weight = old_weight + + return weight + + def unpatch_model(self, device_to=None): + keys = list(self.backup.keys()) + + if self.weight_inplace_update: + for k in keys: + utils.copy_to_param(self.model, k, self.backup[k]) + else: + for k in keys: + utils.set_attr(self.model, k, self.backup[k]) + + self.backup = {} + + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to + + keys = list(self.object_patches_backup.keys()) + for k in keys: + utils.set_attr_raw(self.model, k, self.object_patches_backup[k]) + + self.object_patches_backup = {} diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py new file mode 100644 index 00000000..58e0e62f --- /dev/null +++ b/backend/patcher/lora.py @@ -0,0 +1,293 @@ +# LoRA Implementation Collection form ComfyUI +# Modified by Forge to support greedy loading (load a set or wrong/correct loras to a model and only preserve the correct ones), +# which is important to webui experience + +from backend.misc.diffusers_state_dict import unet_to_diffusers + + +LORA_CLIP_MAP = { + "mlp.fc1": "mlp_fc1", + "mlp.fc2": "mlp_fc2", + "self_attn.k_proj": "self_attn_k_proj", + "self_attn.q_proj": "self_attn_q_proj", + "self_attn.v_proj": "self_attn_v_proj", + "self_attn.out_proj": "self_attn_out_proj", +} + + +def load_lora(lora, to_load): + patch_dict = {} + loaded_keys = set() + for x in to_load: + alpha_name = "{}.alpha".format(x) + alpha = None + if alpha_name in lora.keys(): + alpha = lora[alpha_name].item() + loaded_keys.add(alpha_name) + + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name = "{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + + ######## loha + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + 
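# --- Illustrative sketch (editor's note, not part of the patch) ---
# The ("lora", (up, down, alpha, mid, dora_scale)) tuples built above are the
# format that ModelPatcher.calculate_weight() in backend/patcher/base.py
# consumes. Hypothetical flow for one Linear weight, with made-up shapes:
import torch
from backend.patcher.base import ModelPatcher

up = torch.randn(16, 4)      # lora_up / lora_B factor
down = torch.randn(4, 16)    # lora_down / lora_A factor
patches = {
    "some_module.weight": ("lora", (up, down, None, None, None)),
}
# patcher = ModelPatcher(model, load_device, offload_device)  # assumed loaded model
# patcher.add_patches(patches, strength_patch=0.8)
# patcher.patch_model(device_to=load_device)  # merges (up @ down) * 0.8 into the weight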
loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) + + # glora + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + + w_norm_name = "{}.w_norm".format(x) + b_norm_name = "{}.b_norm".format(x) + w_norm = lora.get(w_norm_name, None) + b_norm = lora.get(b_norm_name, None) + + if w_norm is not None: + loaded_keys.add(w_norm_name) + patch_dict[to_load[x]] = ("diff", (w_norm,)) + if b_norm is not None: + loaded_keys.add(b_norm_name) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) + + diff_name = "{}.diff".format(x) + diff_weight = lora.get(diff_name, None) + if diff_weight is not None: + patch_dict[to_load[x]] = ("diff", (diff_weight,)) + loaded_keys.add(diff_name) + + diff_bias_name = "{}.diff_b".format(x) + diff_bias = lora.get(diff_bias_name, None) + if diff_bias is not None: + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) + loaded_keys.add(diff_bias_name) + + remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys} + return patch_dict, remaining_dict + + +def model_lora_keys_clip(model, key_map={}): + sdk = model.state_dict().keys() + + text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" + clip_l_present = False + for b in range(32): # TODO: clean up + for c in LORA_CLIP_MAP: + k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = 
"lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) + key_map[lora_key] = k + + k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + clip_l_present = True + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) + key_map[lora_key] = k + + k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + if k in sdk: + if clip_l_present: + lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) + key_map[lora_key] = k + else: + lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) + key_map[lora_key] = k + lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k + + for k in sdk: # OneTrainer SD3 lora + if k.startswith("t5xxl.transformer.") and k.endswith(".weight"): + l_key = k[len("t5xxl.transformer."):-len(".weight")] + lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) + key_map[lora_key] = k + + k = "clip_g.transformer.text_projection.weight" + if k in sdk: + key_map["lora_prior_te_text_projection"] = k + key_map["lora_te2_text_projection"] = k + + k = "clip_l.transformer.text_projection.weight" + if k in sdk: + key_map["lora_te1_text_projection"] = k + + return key_map + + +def model_lora_keys_unet(model, key_map={}): + sd = model.state_dict() + sdk = sd.keys() + + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = k + key_map["lora_prior_unet_{}".format(key_lora)] = k + + diffusers_keys = unet_to_diffusers(model.diffusion_model.legacy_config) + for k in diffusers_keys: + if k.endswith(".weight"): + unet_key = "diffusion_model.{}".format(diffusers_keys[k]) + key_lora = k[:-len(".weight")].replace(".", "_") + key_map["lora_unet_{}".format(key_lora)] = unet_key + + diffusers_lora_prefix = ["", "unet."] + for p in diffusers_lora_prefix: + diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) + if diffusers_lora_key.endswith(".to_out.0"): + diffusers_lora_key = diffusers_lora_key[:-2] + key_map[diffusers_lora_key] = unet_key + + # TODO: + + # if isinstance(model, xxxx.modules.model_base.SD3): # Diffusers lora SD3 + # diffusers_keys = xxxx.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + # for k in diffusers_keys: + # if k.endswith(".weight"): + # to = diffusers_keys[k] + # key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format + # key_map[key_lora] = to + # + # key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others? 
+ # key_map[key_lora] = to + # + # key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora + # key_map[key_lora] = to + # + # if isinstance(model, xxxx.modules.model_base.AuraFlow): # Diffusers lora AuraFlow + # diffusers_keys = xxxx.modules.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + # for k in diffusers_keys: + # if k.endswith(".weight"): + # to = diffusers_keys[k] + # key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format + # key_map[key_lora] = to + # + # if isinstance(model, xxxx.modules.model_base.HunyuanDiT): + # for k in sdk: + # if k.startswith("diffusion_model.") and k.endswith(".weight"): + # key_lora = k[len("diffusion_model."):-len(".weight")] + # key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format + + return key_map diff --git a/backend/utils.py b/backend/utils.py new file mode 100644 index 00000000..57b9276b --- /dev/null +++ b/backend/utils.py @@ -0,0 +1,30 @@ +import torch + + +def set_attr(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) + + +def set_attr_raw(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + setattr(obj, attrs[-1], value) + + +def copy_to_param(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + prev.data.copy_(value) + + +def get_attr(obj, attr): + attrs = attr.split(".") + for name in attrs: + obj = getattr(obj, name) + return obj diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py deleted file mode 100644 index 5c235caf..00000000 --- a/extensions-builtin/Lora/lora_patches.py +++ /dev/null @@ -1,6 +0,0 @@ -class LoraPatches: - def __init__(self): - pass - - def undo(self): - pass diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py deleted file mode 100644 index 89987438..00000000 --- a/extensions-builtin/Lora/network.py +++ /dev/null @@ -1,228 +0,0 @@ -from __future__ import annotations -import os -from collections import namedtuple -import enum - -import torch.nn as nn -import torch.nn.functional as F - -from modules import sd_models, cache, errors, hashes, shared -import modules.models.sd3.mmdit - -NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) - -metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} - - -class SdVersion(enum.Enum): - Unknown = 1 - SD1 = 2 - SD2 = 3 - SDXL = 4 - - -class NetworkOnDisk: - def __init__(self, name, filename): - self.name = name - self.filename = filename - self.metadata = {} - self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - - def read_metadata(): - metadata = sd_models.read_metadata_from_safetensors(filename) - - return metadata - - if self.is_safetensors: - try: - self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) - except Exception as e: - errors.display(e, f"reading lora {filename}") - - if self.metadata: - m = {} - for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): - m[k] = v - - self.metadata = m - - self.alias = self.metadata.get('ss_output_name', 
self.name) - - self.hash = None - self.shorthash = None - self.set_hash( - self.metadata.get('sshs_model_hash') or - hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or - '' - ) - - self.sd_version = self.detect_version() - - def detect_version(self): - if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): - return SdVersion.SDXL - elif str(self.metadata.get('ss_v2', "")) == "True": - return SdVersion.SD2 - elif len(self.metadata): - return SdVersion.SD1 - - return SdVersion.Unknown - - def set_hash(self, v): - self.hash = v - self.shorthash = self.hash[0:12] - - if self.shorthash: - import networks - networks.available_network_hash_lookup[self.shorthash] = self - - def read_hash(self): - if not self.hash: - self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') - - def get_alias(self): - import networks - if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: - return self.name - else: - return self.alias - - -class Network: # LoraModule - def __init__(self, name, network_on_disk: NetworkOnDisk): - self.name = name - self.network_on_disk = network_on_disk - self.te_multiplier = 1.0 - self.unet_multiplier = 1.0 - self.dyn_dim = None - self.modules = {} - self.bundle_embeddings = {} - self.mtime = None - - self.mentioned_name = None - """the text that was used to add the network to prompt - can be either name or an alias""" - - -class ModuleType: - def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: - return None - - -class NetworkModule: - def __init__(self, net: Network, weights: NetworkWeights): - self.network = net - self.network_key = weights.network_key - self.sd_key = weights.sd_key - self.sd_module = weights.sd_module - - if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear): - s = self.sd_module.weight.shape - self.shape = (s[0] // 3, s[1]) - elif hasattr(self.sd_module, 'weight'): - self.shape = self.sd_module.weight.shape - elif isinstance(self.sd_module, nn.MultiheadAttention): - # For now, only self-attn use Pytorch's MHA - # So assume all qkvo proj have same shape - self.shape = self.sd_module.out_proj.weight.shape - else: - self.shape = None - - self.ops = None - self.extra_kwargs = {} - if isinstance(self.sd_module, nn.Conv2d): - self.ops = F.conv2d - self.extra_kwargs = { - 'stride': self.sd_module.stride, - 'padding': self.sd_module.padding - } - elif isinstance(self.sd_module, nn.Linear): - self.ops = F.linear - elif isinstance(self.sd_module, nn.LayerNorm): - self.ops = F.layer_norm - self.extra_kwargs = { - 'normalized_shape': self.sd_module.normalized_shape, - 'eps': self.sd_module.eps - } - elif isinstance(self.sd_module, nn.GroupNorm): - self.ops = F.group_norm - self.extra_kwargs = { - 'num_groups': self.sd_module.num_groups, - 'eps': self.sd_module.eps - } - - self.dim = None - self.bias = weights.w.get("bias") - self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None - self.scale = weights.w["scale"].item() if "scale" in weights.w else None - - self.dora_scale = weights.w.get("dora_scale", None) - self.dora_norm_dims = len(self.shape) - 1 - - def multiplier(self): - if 'transformer' in self.sd_key[:20]: - return self.network.te_multiplier - else: - return self.network.unet_multiplier - - def calc_scale(self): - if self.scale is not None: - return self.scale - if self.dim is not None and self.alpha is not None: - return self.alpha 
/ self.dim - - return 1.0 - - def apply_weight_decompose(self, updown, orig_weight): - # Match the device/dtype - orig_weight = orig_weight.to(updown.dtype) - dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) - updown = updown.to(orig_weight.device) - - merged_scale1 = updown + orig_weight - merged_scale1_norm = ( - merged_scale1.transpose(0, 1) - .reshape(merged_scale1.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) - .transpose(0, 1) - ) - - dora_merged = ( - merged_scale1 * (dora_scale / merged_scale1_norm) - ) - final_updown = dora_merged - orig_weight - return final_updown - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=updown.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - updown = updown * self.calc_scale() - - if self.dora_scale is not None: - updown = self.apply_weight_decompose(updown, orig_weight) - - return updown * self.multiplier(), ex_bias - - def calc_updown(self, target): - raise NotImplementedError() - - def forward(self, x, y): - """A general forward implementation for all modules""" - if self.ops is None: - raise NotImplementedError() - else: - updown, ex_bias = self.calc_updown(self.sd_module.weight) - return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) - diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/sd_forge_lora/extra_networks_lora.py similarity index 97% rename from extensions-builtin/Lora/extra_networks_lora.py rename to extensions-builtin/sd_forge_lora/extra_networks_lora.py index 17a620f7..33edf465 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/sd_forge_lora/extra_networks_lora.py @@ -1,62 +1,62 @@ -from modules import extra_networks, shared -import networks - - -class ExtraNetworkLora(extra_networks.ExtraNetwork): - def __init__(self): - super().__init__('lora') - - self.errors = {} - """mapping of network names to the number of errors the network had during operation""" - - remove_symbols = str.maketrans('', '', ":,") - - def activate(self, p, params_list): - additional = shared.opts.sd_lora - - self.errors.clear() - - if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): - p.all_prompts = [x + f"" for x in p.all_prompts] - params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) - - names = [] - te_multipliers = [] - unet_multipliers = [] - dyn_dims = [] - for params in params_list: - assert params.items - - names.append(params.positional[0]) - - te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 - te_multiplier = float(params.named.get("te", te_multiplier)) - - unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier - unet_multiplier = float(params.named.get("unet", unet_multiplier)) - - dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None - dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim - - te_multipliers.append(te_multiplier) 
- unet_multipliers.append(unet_multiplier) - dyn_dims.append(dyn_dim) - - networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) - - if shared.opts.lora_add_hashes_to_infotext: - if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"): - p.lora_hashes = {} - - for item in networks.loaded_networks: - if item.network_on_disk.shorthash and item.mentioned_name: - p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash - - if p.lora_hashes: - p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items()) - - def deactivate(self, p): - if self.errors: - p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items())) - - self.errors.clear() +from modules import extra_networks, shared +import networks + + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lora') + + self.errors = {} + """mapping of network names to the number of errors the network had during operation""" + + remove_symbols = str.maketrans('', '', ":,") + + def activate(self, p, params_list): + additional = shared.opts.sd_lora + + self.errors.clear() + + if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + + names = [] + te_multipliers = [] + unet_multipliers = [] + dyn_dims = [] + for params in params_list: + assert params.items + + names.append(params.positional[0]) + + te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 + te_multiplier = float(params.named.get("te", te_multiplier)) + + unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier + unet_multiplier = float(params.named.get("unet", unet_multiplier)) + + dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None + dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim + + te_multipliers.append(te_multiplier) + unet_multipliers.append(unet_multiplier) + dyn_dims.append(dyn_dim) + + networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) + + if shared.opts.lora_add_hashes_to_infotext: + if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"): + p.lora_hashes = {} + + for item in networks.loaded_networks: + if item.network_on_disk.shorthash and item.mentioned_name: + p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash + + if p.lora_hashes: + p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items()) + + def deactivate(self, p): + if self.errors: + p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items())) + + self.errors.clear() diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/sd_forge_lora/lora.py similarity index 97% rename from extensions-builtin/Lora/lora.py rename to extensions-builtin/sd_forge_lora/lora.py index 9365aa74..6186538e 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/sd_forge_lora/lora.py @@ -1,9 +1,9 @@ -import networks - -list_available_loras = networks.list_available_networks - -available_loras = networks.available_networks -available_lora_aliases = networks.available_network_aliases -available_lora_hash_lookup = 
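# --- Illustrative sketch (editor's note, not part of the patch) ---
# How activate() above turns extra-network parameters into multipliers, based
# on its positional/named handling (the usual <lora:...> prompt form; the
# values here are made up):
#
#   <lora:myLora>                  -> te = 1.0, unet = 1.0, dyn_dim = None
#   <lora:myLora:0.6>              -> te = 0.6, unet = 0.6
#   <lora:myLora:0.6:0.8>          -> te = 0.6, unet = 0.8
#   <lora:myLora:0.6:0.8:16>       -> te = 0.6, unet = 0.8, dyn_dim = 16
#   <lora:myLora:te=0.5:unet=1.0>  -> named values override positional ones
#
# The collected lists are then handed to
# networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims).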
networks.available_network_hash_lookup -forbidden_lora_aliases = networks.forbidden_network_aliases -loaded_loras = networks.loaded_networks +import networks + +list_available_loras = networks.list_available_networks + +available_loras = networks.available_networks +available_lora_aliases = networks.available_network_aliases +available_lora_hash_lookup = networks.available_network_hash_lookup +forbidden_lora_aliases = networks.forbidden_network_aliases +loaded_loras = networks.loaded_networks diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/sd_forge_lora/lora_logger.py similarity index 100% rename from extensions-builtin/Lora/lora_logger.py rename to extensions-builtin/sd_forge_lora/lora_logger.py diff --git a/extensions-builtin/sd_forge_lora/network.py b/extensions-builtin/sd_forge_lora/network.py new file mode 100644 index 00000000..63de6562 --- /dev/null +++ b/extensions-builtin/sd_forge_lora/network.py @@ -0,0 +1,74 @@ +import os + +from modules import sd_models, cache, errors, hashes, shared + + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + + +class NetworkOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + + return metadata + + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.alias = self.metadata.get('ss_output_name', self.name) + + self.hash = None + self.shorthash = None + self.set_hash( + self.metadata.get('sshs_model_hash') or + hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or + '' + ) + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + if self.shorthash: + import networks + networks.available_network_hash_lookup[self.shorthash] = self + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + import networks + if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: + return self.name + else: + return self.alias + + +class Network: + def __init__(self, name, network_on_disk: NetworkOnDisk): + self.name = name + self.network_on_disk = network_on_disk + self.te_multiplier = 1.0 + self.unet_multiplier = 1.0 + self.dyn_dim = None + self.modules = {} + self.bundle_embeddings = {} + self.mtime = None + self.mentioned_name = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/sd_forge_lora/networks.py similarity index 70% rename from extensions-builtin/Lora/networks.py rename to extensions-builtin/sd_forge_lora/networks.py index 9ebeda6e..35b69f1c 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/sd_forge_lora/networks.py @@ -1,272 +1,171 @@ -from __future__ import annotations -import gradio as gr -import logging -import os -import re - -import lora_patches -import functools -import network - -import torch -from 
typing import Union - -from modules import shared, sd_models, errors, scripts -from ldm_patched.modules.utils import load_torch_file -from ldm_patched.modules.sd import load_lora_for_models - - -@functools.lru_cache(maxsize=5) -def load_lora_state_dict(filename): - return load_torch_file(filename, safe_load=True) - - -def convert_diffusers_name_to_compvis(key, is_sd2): - pass - - -def assign_network_names_to_compvis_modules(sd_model): - pass - - -class BundledTIHash(str): - def __init__(self, hash_str): - self.hash = hash_str - - def __str__(self): - return self.hash if shared.opts.lora_bundled_ti_to_infotext else '' - - -def load_network(name, network_on_disk): - net = network.Network(name, network_on_disk) - net.mtime = os.path.getmtime(network_on_disk.filename) - - return net - - -def purge_networks_from_memory(): - pass - - -def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): - global lora_state_dict_cache - - current_sd = sd_models.model_data.get_sd_model() - if current_sd is None: - return - - loaded_networks.clear() - - unavailable_networks = [] - for name in names: - if name.lower() in forbidden_network_aliases and available_networks.get(name) is None: - unavailable_networks.append(name) - elif available_network_aliases.get(name) is None: - unavailable_networks.append(name) - - if unavailable_networks: - update_available_networks_by_names(unavailable_networks) - - networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names] - if any(x is None for x in networks_on_disk): - list_available_networks() - networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names] - - for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): - try: - net = load_network(name, network_on_disk) - except Exception as e: - errors.display(e, f"loading network {network_on_disk.filename}") - continue - net.mentioned_name = name - network_on_disk.read_hash() - loaded_networks.append(net) - - compiled_lora_targets = [] - for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers): - compiled_lora_targets.append([a.filename, b, c]) - - compiled_lora_targets_hash = str(compiled_lora_targets) - - if current_sd.current_lora_hash == compiled_lora_targets_hash: - return - - current_sd.current_lora_hash = compiled_lora_targets_hash - current_sd.forge_objects.unet = current_sd.forge_objects_original.unet - current_sd.forge_objects.clip = current_sd.forge_objects_original.clip - - for filename, strength_model, strength_clip in compiled_lora_targets: - lora_sd = load_lora_state_dict(filename) - current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models( - current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, - filename=filename) - - current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy() - return - - -def allowed_layer_without_weight(layer): - if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine: - return True - - return False - - -def store_weights_backup(weight): - if weight is None: - return None - - return weight.to(devices.cpu, copy=True) - - -def restore_weights_backup(obj, field, weight): - if weight is None: - setattr(obj, field, None) - return - - getattr(obj, field).copy_(weight) - - -def 
network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): - pass - - -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): - pass - - -def network_forward(org_module, input, original_forward): - pass - - -def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): - pass - - -def network_Linear_forward(self, input): - pass - - -def network_Linear_load_state_dict(self, *args, **kwargs): - pass - - -def network_Conv2d_forward(self, input): - pass - - -def network_Conv2d_load_state_dict(self, *args, **kwargs): - pass - - -def network_GroupNorm_forward(self, input): - pass - - -def network_GroupNorm_load_state_dict(self, *args, **kwargs): - pass - - -def network_LayerNorm_forward(self, input): - pass - - -def network_LayerNorm_load_state_dict(self, *args, **kwargs): - pass - - -def network_MultiheadAttention_forward(self, *args, **kwargs): - pass - - -def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): - pass - - -def process_network_files(names: list[str] | None = None): - candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) - for filename in candidates: - if os.path.isdir(filename): - continue - name = os.path.splitext(os.path.basename(filename))[0] - # if names is provided, only load networks with names in the list - if names and name not in names: - continue - try: - entry = network.NetworkOnDisk(name, filename) - except OSError: # should catch FileNotFoundError and PermissionError etc. - errors.report(f"Failed to load network {name} from {filename}", exc_info=True) - continue - - available_networks[name] = entry - - if entry.alias in available_network_aliases: - forbidden_network_aliases[entry.alias.lower()] = 1 - - available_network_aliases[name] = entry - available_network_aliases[entry.alias] = entry - - -def update_available_networks_by_names(names: list[str]): - process_network_files(names) - - -def list_available_networks(): - available_networks.clear() - available_network_aliases.clear() - forbidden_network_aliases.clear() - available_network_hash_lookup.clear() - forbidden_network_aliases.update({"none": 1, "Addams": 1}) - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - - process_network_files() - - -re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") - - -def infotext_pasted(infotext, params): - if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: - return # if the other extension is active, it will handle those fields, no need to do anything - - added = [] - - for k in params: - if not k.startswith("AddNet Model "): - continue - - num = k[13:] - - if params.get("AddNet Module " + num) != "LoRA": - continue - - name = params.get("AddNet Model " + num) - if name is None: - continue - - m = re_network_name.match(name) - if m: - name = m.group(1) - - multiplier = params.get("AddNet Weight A " + num, "1.0") - - added.append(f"") - - if added: - params["Prompt"] += "\n" + "".join(added) - - -originals: lora_patches.LoraPatches = None - -extra_network_lora = None - -available_networks = {} -available_network_aliases = {} -loaded_networks = [] -loaded_bundle_embeddings = {} -networks_in_memory = {} -available_network_hash_lookup = {} -forbidden_network_aliases = {} - -list_available_networks() +from __future__ import annotations +import gradio as gr +import 
logging +import os +import re + +import functools +import network + +import torch +from typing import Union + +from modules import shared, sd_models, errors, scripts +from ldm_patched.modules.utils import load_torch_file +from ldm_patched.modules.sd import load_lora_for_models + + +@functools.lru_cache(maxsize=5) +def load_lora_state_dict(filename): + return load_torch_file(filename, safe_load=True) + + +def load_network(name, network_on_disk): + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + + return net + + +def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + global lora_state_dict_cache + + current_sd = sd_models.model_data.get_sd_model() + if current_sd is None: + return + + loaded_networks.clear() + + unavailable_networks = [] + for name in names: + if name.lower() in forbidden_network_aliases and available_networks.get(name) is None: + unavailable_networks.append(name) + elif available_network_aliases.get(name) is None: + unavailable_networks.append(name) + + if unavailable_networks: + update_available_networks_by_names(unavailable_networks) + + networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names] + + for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): + try: + net = load_network(name, network_on_disk) + except Exception as e: + errors.display(e, f"loading network {network_on_disk.filename}") + continue + net.mentioned_name = name + network_on_disk.read_hash() + loaded_networks.append(net) + + compiled_lora_targets = [] + for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers): + compiled_lora_targets.append([a.filename, b, c]) + + compiled_lora_targets_hash = str(compiled_lora_targets) + + if current_sd.current_lora_hash == compiled_lora_targets_hash: + return + + current_sd.current_lora_hash = compiled_lora_targets_hash + current_sd.forge_objects.unet = current_sd.forge_objects_original.unet + current_sd.forge_objects.clip = current_sd.forge_objects_original.clip + + for filename, strength_model, strength_clip in compiled_lora_targets: + lora_sd = load_lora_state_dict(filename) + current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models( + current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, + filename=filename) + + current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy() + return + + +def process_network_files(names: list[str] | None = None): + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + for filename in candidates: + if os.path.isdir(filename): + continue + name = os.path.splitext(os.path.basename(filename))[0] + # if names is provided, only load networks with names in the list + if names and name not in names: + continue + try: + entry = network.NetworkOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. 
+ errors.report(f"Failed to load network {name} from {filename}", exc_info=True) + continue + + available_networks[name] = entry + + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + + available_network_aliases[name] = entry + available_network_aliases[entry.alias] = entry + + +def update_available_networks_by_names(names: list[str]): + process_network_files(names) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + process_network_files() + + +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k in params: + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_network_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"") + + if added: + params["Prompt"] += "\n" + "".join(added) + + +extra_network_lora = None + +available_networks = {} +available_network_aliases = {} +loaded_networks = [] +loaded_bundle_embeddings = {} +networks_in_memory = {} +available_network_hash_lookup = {} +forbidden_network_aliases = {} + +list_available_networks() diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/sd_forge_lora/preload.py similarity index 98% rename from extensions-builtin/Lora/preload.py rename to extensions-builtin/sd_forge_lora/preload.py index 52fab29b..763f9421 100644 --- a/extensions-builtin/Lora/preload.py +++ b/extensions-builtin/sd_forge_lora/preload.py @@ -1,8 +1,8 @@ -import os -from modules import paths -from modules.paths_internal import normalized_filepath - - -def preload(parser): - parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) - parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS')) +import os +from modules import paths +from modules.paths_internal import normalized_filepath + + +def preload(parser): + parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) + parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS')) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/sd_forge_lora/scripts/lora_script.py similarity index 89% rename from extensions-builtin/Lora/scripts/lora_script.py rename to extensions-builtin/sd_forge_lora/scripts/lora_script.py index 9e9e4ad8..a8a26d51 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/sd_forge_lora/scripts/lora_script.py @@ -1,101 +1,90 @@ -import re - 
-import gradio as gr -from fastapi import FastAPI - -import network -import networks -import lora # noqa:F401 -import lora_patches -import extra_networks_lora -import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks, shared - - -def unload(): - networks.originals.undo() - - -def before_ui(): - ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - - networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora() - extra_networks.register_extra_network(networks.extra_network_lora) - - -networks.originals = lora_patches.LoraPatches() - -script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) -script_callbacks.on_script_unloaded(unload) -script_callbacks.on_before_ui(before_ui) -script_callbacks.on_infotext_pasted(networks.infotext_pasted) - - -shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { - "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), - "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), - "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), - "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), - "lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), - "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), - "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), - "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), - "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), -})) - - -shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { - "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), -})) - - -def create_lora_json(obj: network.NetworkOnDisk): - return { - "name": obj.name, - "alias": obj.alias, - "path": obj.filename, - "metadata": obj.metadata, - } - - -def api_networks(_: gr.Blocks, app: FastAPI): - @app.get("/sdapi/v1/loras") - async def get_loras(): - return [create_lora_json(obj) for obj in networks.available_networks.values()] - - @app.post("/sdapi/v1/refresh-loras") - async def refresh_loras(): - return networks.list_available_networks() - - -script_callbacks.on_app_started(api_networks) - -re_lora = re.compile("= 16 - - -re_word = re.compile(r"[-_\w']+") -re_comma = re.compile(r" *, *") - - -def build_tags(metadata): - tags = {} - - ss_tag_frequency = metadata.get("ss_tag_frequency", {}) - if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'): - for _, tags_dict in ss_tag_frequency.items(): - for tag, tag_count in tags_dict.items(): - tag = tag.strip() - tags[tag] = tags.get(tag, 0) + int(tag_count) - - if tags 
and is_non_comma_tagset(tags): - new_tags = {} - - for text, text_count in tags.items(): - for word in re.findall(re_word, text): - if len(word) < 3: - continue - - new_tags[word] = new_tags.get(word, 0) + text_count - - tags = new_tags - - ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True) - - return [(tag, tags[tag]) for tag in ordered_tags] - - -class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): - def __init__(self, ui, tabname, page): - super().__init__(ui, tabname, page) - - self.select_sd_version = None - - self.taginfo = None - self.edit_activation_text = None - self.slider_preferred_weight = None - self.edit_notes = None - - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): - user_metadata = self.get_user_metadata(name) - user_metadata["description"] = desc - user_metadata["sd version"] = sd_version - user_metadata["activation text"] = activation_text - user_metadata["preferred weight"] = preferred_weight - user_metadata["negative text"] = negative_text - user_metadata["notes"] = notes - - self.write_user_metadata(name, user_metadata) - - def get_metadata_table(self, name): - table = super().get_metadata_table(name) - item = self.page.items.get(name, {}) - metadata = item.get("metadata") or {} - - keys = { - 'ss_output_name': "Output name:", - 'ss_sd_model_name': "Model:", - 'ss_clip_skip': "Clip skip:", - 'ss_network_module': "Kohya module:", - } - - for key, label in keys.items(): - value = metadata.get(key, None) - if value is not None and str(value) != "None": - table.append((label, html.escape(value))) - - ss_training_started_at = metadata.get('ss_training_started_at') - if ss_training_started_at: - table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M'))) - - ss_bucket_info = metadata.get("ss_bucket_info") - if ss_bucket_info and "buckets" in ss_bucket_info: - resolutions = {} - for _, bucket in ss_bucket_info["buckets"].items(): - resolution = bucket["resolution"] - resolution = f'{resolution[1]}x{resolution[0]}' - - resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"]) - - resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True) - resolutions_text = html.escape(", ".join(resolutions_list[0:4])) - if len(resolutions) > 4: - resolutions_text += ", ..." 
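Note: build_tags (kept unchanged in the rewritten editor later in this diff) flattens kohya's ss_tag_frequency metadata into tag counts and falls back to word-splitting when the captions look like sentences rather than comma-separated tags. A small runnable sketch of that logic with a made-up metadata dict; build_tag_counts is a hypothetical name:

```python
import re

re_word = re.compile(r"[-_\w']+")

def is_non_comma_tagset(tags):
    # Sentence-style captions tend to produce long "tags";
    # an average length of 16 characters is the heuristic cutoff.
    return sum(len(t) for t in tags) / len(tags) >= 16

def build_tag_counts(ss_tag_frequency):
    """Sketch of build_tags: flatten the per-directory tag frequencies into one
    tag -> count mapping, splitting sentence captions into words when needed."""
    tags = {}
    for tags_dict in ss_tag_frequency.values():
        for tag, count in tags_dict.items():
            tag = tag.strip()
            tags[tag] = tags.get(tag, 0) + int(count)
    if tags and is_non_comma_tagset(tags):
        words = {}
        for text, count in tags.items():
            for word in re.findall(re_word, text):
                if len(word) >= 3:                       # short words are dropped
                    words[word] = words.get(word, 0) + count
        tags = words
    return sorted(tags.items(), key=lambda kv: kv[1], reverse=True)

print(build_tag_counts({"img_dir": {"1girl": "12", "smile": "7"}}))
```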
- resolutions_text = f"{resolutions_text}" - - table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text)) - - image_count = 0 - for _, params in metadata.get("ss_dataset_dirs", {}).items(): - image_count += int(params.get("img_count", 0)) - - if image_count: - table.append(("Dataset size:", image_count)) - - return table - - def put_values_into_components(self, name): - user_metadata = self.get_user_metadata(name) - values = super().put_values_into_components(name) - - item = self.page.items.get(name, {}) - metadata = item.get("metadata") or {} - - tags = build_tags(metadata) - gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] - - return [ - *values[0:5], - item.get("sd_version", "Unknown"), - gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), - user_metadata.get('activation text', ''), - float(user_metadata.get('preferred weight', 0.0)), - user_metadata.get('negative text', ''), - gr.update(visible=True if tags else False), - gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), - ] - - def generate_random_prompt(self, name): - item = self.page.items.get(name, {}) - metadata = item.get("metadata") or {} - tags = build_tags(metadata) - - return self.generate_random_prompt_from_tags(tags) - - def generate_random_prompt_from_tags(self, tags): - max_count = None - res = [] - for tag, count in tags: - if not max_count: - max_count = count - - v = random.random() * max_count - if count > v: - for x in "({[]})": - tag = tag.replace(x, '\\' + x) - res.append(tag) - - return ", ".join(sorted(res)) - - def create_extra_default_items_in_left_column(self): - - # this would be a lot better as gr.Radio but I can't make it work - self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) - - def create_editor(self): - self.create_default_editor_elems() - - self.taginfo = gr.HighlightedText(label="Training dataset tags") - self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") - self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") - with gr.Row() as row_random_prompt: - with gr.Column(scale=8): - random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) - - with gr.Column(scale=1, min_width=120): - generate_random_prompt = gr.Button('Generate', size="lg", scale=1) - - self.edit_notes = gr.TextArea(label='Notes', lines=4) - - generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False) - - def select_tag(activation_text, evt: gr.SelectData): - tag = evt.value[0] - - words = re.split(re_comma, activation_text) - if tag in words: - words = [x for x in words if x != tag and x.strip()] - return ", ".join(words) - - return activation_text + ", " + tag if activation_text else tag - - self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False) - - self.create_default_buttons() - - viewed_components = [ - self.edit_name, - self.edit_description, - self.html_filedata, - self.html_preview, - self.edit_notes, - self.select_sd_version, - self.taginfo, - self.edit_activation_text, - self.slider_preferred_weight, 
- self.edit_negative_text, - row_random_prompt, - random_prompt, - ] - - self.button_edit\ - .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ - .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - - edited_components = [ - self.edit_description, - self.select_sd_version, - self.edit_activation_text, - self.slider_preferred_weight, - self.edit_negative_text, - self.edit_notes, - ] - - - self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) +import datetime +import html +import random + +import gradio as gr +import re + +from modules import ui_extra_networks_user_metadata + + +def is_non_comma_tagset(tags): + average_tag_length = sum(len(x) for x in tags.keys()) / len(tags) + + return average_tag_length >= 16 + + +re_word = re.compile(r"[-_\w']+") +re_comma = re.compile(r" *, *") + + +def build_tags(metadata): + tags = {} + + ss_tag_frequency = metadata.get("ss_tag_frequency", {}) + if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'): + for _, tags_dict in ss_tag_frequency.items(): + for tag, tag_count in tags_dict.items(): + tag = tag.strip() + tags[tag] = tags.get(tag, 0) + int(tag_count) + + if tags and is_non_comma_tagset(tags): + new_tags = {} + + for text, text_count in tags.items(): + for word in re.findall(re_word, text): + if len(word) < 3: + continue + + new_tags[word] = new_tags.get(word, 0) + text_count + + tags = new_tags + + ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True) + + return [(tag, tags[tag]) for tag in ordered_tags] + + +class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.select_sd_version = None + + self.taginfo = None + self.edit_activation_text = None + self.slider_preferred_weight = None + self.edit_notes = None + + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["sd version"] = sd_version + user_metadata["activation text"] = activation_text + user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text + user_metadata["notes"] = notes + + self.write_user_metadata(name, user_metadata) + + def get_metadata_table(self, name): + table = super().get_metadata_table(name) + item = self.page.items.get(name, {}) + metadata = item.get("metadata") or {} + + keys = { + 'ss_output_name': "Output name:", + 'ss_sd_model_name': "Model:", + 'ss_clip_skip': "Clip skip:", + 'ss_network_module': "Kohya module:", + } + + for key, label in keys.items(): + value = metadata.get(key, None) + if value is not None and str(value) != "None": + table.append((label, html.escape(value))) + + ss_training_started_at = metadata.get('ss_training_started_at') + if ss_training_started_at: + table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M'))) + + ss_bucket_info = metadata.get("ss_bucket_info") + if ss_bucket_info and "buckets" in ss_bucket_info: + resolutions = {} + for _, bucket in ss_bucket_info["buckets"].items(): + resolution = bucket["resolution"] + resolution = f'{resolution[1]}x{resolution[0]}' + + resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"]) + + resolutions_list = sorted(resolutions.keys(), key=resolutions.get, 
reverse=True) + resolutions_text = html.escape(", ".join(resolutions_list[0:4])) + if len(resolutions) > 4: + resolutions_text += ", ..." + resolutions_text = f"{resolutions_text}" + + table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text)) + + image_count = 0 + for _, params in metadata.get("ss_dataset_dirs", {}).items(): + image_count += int(params.get("img_count", 0)) + + if image_count: + table.append(("Dataset size:", image_count)) + + return table + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + values = super().put_values_into_components(name) + + item = self.page.items.get(name, {}) + metadata = item.get("metadata") or {} + + tags = build_tags(metadata) + gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] + + return [ + *values[0:5], + item.get("sd_version", "Unknown"), + gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), + user_metadata.get('activation text', ''), + float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), + gr.update(visible=True if tags else False), + gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), + ] + + def generate_random_prompt(self, name): + item = self.page.items.get(name, {}) + metadata = item.get("metadata") or {} + tags = build_tags(metadata) + + return self.generate_random_prompt_from_tags(tags) + + def generate_random_prompt_from_tags(self, tags): + max_count = None + res = [] + for tag, count in tags: + if not max_count: + max_count = count + + v = random.random() * max_count + if count > v: + for x in "({[]})": + tag = tag.replace(x, '\\' + x) + res.append(tag) + + return ", ".join(sorted(res)) + + def create_extra_default_items_in_left_column(self): + + # this would be a lot better as gr.Radio but I can't make it work + self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + + def create_editor(self): + self.create_default_editor_elems() + + self.taginfo = gr.HighlightedText(label="Training dataset tags") + self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") + self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") + with gr.Row() as row_random_prompt: + with gr.Column(scale=8): + random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) + + with gr.Column(scale=1, min_width=120): + generate_random_prompt = gr.Button('Generate', size="lg", scale=1) + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False) + + def select_tag(activation_text, evt: gr.SelectData): + tag = evt.value[0] + + words = re.split(re_comma, activation_text) + if tag in words: + words = [x for x in words if x != tag and x.strip()] + return ", ".join(words) + + return activation_text + ", " + tag if activation_text else tag + + self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False) + + self.create_default_buttons() + + viewed_components = [ + self.edit_name, + self.edit_description, + self.html_filedata, + 
self.html_preview, + self.edit_notes, + self.select_sd_version, + self.taginfo, + self.edit_activation_text, + self.slider_preferred_weight, + self.edit_negative_text, + row_random_prompt, + random_prompt, + ] + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + edited_components = [ + self.edit_description, + self.select_sd_version, + self.edit_activation_text, + self.slider_preferred_weight, + self.edit_negative_text, + self.edit_notes, + ] + + + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py similarity index 69% rename from extensions-builtin/Lora/ui_extra_networks_lora.py rename to extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py index 35e71be3..f4701fee 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py @@ -1,90 +1,69 @@ -import os - -import network -import networks - -from modules import shared, ui_extra_networks -from modules.ui_extra_networks import quote_js -from ui_edit_user_metadata import LoraUserMetadataEditor - - -class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): - def __init__(self): - super().__init__('Lora') - - def refresh(self): - networks.list_available_networks() - - def create_item(self, name, index=None, enable_filter=True): - lora_on_disk = networks.available_networks.get(name) - if lora_on_disk is None: - return - - path, ext = os.path.splitext(lora_on_disk.filename) - - alias = lora_on_disk.get_alias() - - search_terms = [self.search_terms_from_path(lora_on_disk.filename)] - if lora_on_disk.hash: - search_terms.append(lora_on_disk.hash) - item = { - "name": name, - "filename": lora_on_disk.filename, - "shorthash": lora_on_disk.shorthash, - "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata), - "description": self.find_description(path), - "search_terms": search_terms, - "local_preview": f"{path}.{shared.opts.samples_format}", - "metadata": lora_on_disk.metadata, - "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, - "sd_version": lora_on_disk.sd_version.name, - } - - self.read_user_metadata(item) - activation_text = item["user_metadata"].get("activation text") - preferred_weight = item["user_metadata"].get("preferred weight", 0.0) - item["prompt"] = quote_js(f"") - - if activation_text: - item["prompt"] += " + " + quote_js(" " + activation_text) - - negative_prompt = item["user_metadata"].get("negative text") - item["negative_prompt"] = quote_js("") - if negative_prompt: - item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') - - sd_version = item["user_metadata"].get("sd version") - if sd_version in network.SdVersion.__members__: - item["sd_version"] = sd_version - sd_version = network.SdVersion[sd_version] - else: - sd_version = lora_on_disk.sd_version - - if shared.opts.lora_filter_disabled or not enable_filter or not shared.sd_model: - pass - elif sd_version == network.SdVersion.Unknown: - model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 - if model_version.name in shared.opts.lora_hide_unknown_for_versions: - return None - elif shared.sd_model.is_sdxl and 
sd_version != network.SdVersion.SDXL: - return None - elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2: - return None - elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1: - return None - - return item - - def list_items(self): - # instantiate a list to protect against concurrent modification - names = list(networks.available_networks) - for index, name in enumerate(names): - item = self.create_item(name, index) - if item is not None: - yield item - - def allowed_directories_for_previews(self): - return [shared.cmd_opts.lora_dir] - - def create_user_metadata_editor(self, ui, tabname): - return LoraUserMetadataEditor(ui, tabname, self) +import os + +import network +import networks + +from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js +from ui_edit_user_metadata import LoraUserMetadataEditor + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + + def refresh(self): + networks.list_available_networks() + + def create_item(self, name, index=None, enable_filter=True): + lora_on_disk = networks.available_networks.get(name) + if lora_on_disk is None: + return + + path, ext = os.path.splitext(lora_on_disk.filename) + + alias = lora_on_disk.get_alias() + + search_terms = [self.search_terms_from_path(lora_on_disk.filename)] + if lora_on_disk.hash: + search_terms.append(lora_on_disk.hash) + item = { + "name": name, + "filename": lora_on_disk.filename, + "shorthash": lora_on_disk.shorthash, + "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata), + "description": self.find_description(path), + "search_terms": search_terms, + "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": lora_on_disk.metadata, + "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + } + + self.read_user_metadata(item) + activation_text = item["user_metadata"].get("activation text") + preferred_weight = item["user_metadata"].get("preferred weight", 0.0) + item["prompt"] = quote_js(f"") + + if activation_text: + item["prompt"] += " + " + quote_js(" " + activation_text) + + negative_prompt = item["user_metadata"].get("negative text") + item["negative_prompt"] = quote_js("") + if negative_prompt: + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') + + return item + + def list_items(self): + # instantiate a list to protect against concurrent modification + names = list(networks.available_networks) + for index, name in enumerate(names): + item = self.create_item(name, index) + if item is not None: + yield item + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.lora_dir] + + def create_user_metadata_editor(self, ui, tabname): + return LoraUserMetadataEditor(ui, tabname, self) diff --git a/ldm_patched/modules/lora.py b/ldm_patched/modules/lora.py index b51898f5..6bdbb3d0 100644 --- a/ldm_patched/modules/lora.py +++ b/ldm_patched/modules/lora.py @@ -1,226 +1 @@ -# 1st edit by https://github.com/comfyanonymous/ComfyUI -# 2nd edit by Forge Official - - -import ldm_patched.modules.utils - -LORA_CLIP_MAP = { - "mlp.fc1": "mlp_fc1", - "mlp.fc2": "mlp_fc2", - "self_attn.k_proj": "self_attn_k_proj", - "self_attn.q_proj": "self_attn_q_proj", - "self_attn.v_proj": "self_attn_v_proj", - "self_attn.out_proj": "self_attn_out_proj", -} - - -def load_lora(lora, to_load): - patch_dict = {} - loaded_keys = set() - for x in to_load: - alpha_name = "{}.alpha".format(x) - alpha = None - 
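Note: LORA_CLIP_MAP in the old ldm_patched/modules/lora.py (which this diff reduces to a one-line re-export) translates CLIP text-encoder module names into the underscore-joined suffixes used in LoRA key names. A small sketch of the resulting naming scheme; clip_lora_key is a hypothetical helper and only a subset of the map is repeated here:

```python
LORA_CLIP_MAP = {
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.k_proj": "self_attn_k_proj",
    # ... remaining entries as defined in the removed module above
}

def clip_lora_key(layer, module, prefix="lora_te"):
    """Sketch of the naming rule used by model_lora_keys_clip (later in this file):
    '<prefix>_text_model_encoder_layers_<layer>_<mapped module>'."""
    return f"{prefix}_text_model_encoder_layers_{layer}_{LORA_CLIP_MAP[module]}"

print(clip_lora_key(0, "self_attn.q_proj"))               # lora_te_text_model_encoder_layers_0_self_attn_q_proj
print(clip_lora_key(0, "self_attn.q_proj", "lora_te1"))   # SDXL-base variant for the first text encoder
```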
if alpha_name in lora.keys(): - alpha = lora[alpha_name].item() - loaded_keys.add(alpha_name) - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) - - #glora - a1_name = "{}.a1.weight".format(x) - a2_name = "{}.a2.weight".format(x) - b1_name = "{}.b1.weight".format(x) - b2_name = "{}.b2.weight".format(x) - if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) - loaded_keys.add(a1_name) - loaded_keys.add(a2_name) - loaded_keys.add(b1_name) - loaded_keys.add(b2_name) - - 
w_norm_name = "{}.w_norm".format(x) - b_norm_name = "{}.b_norm".format(x) - w_norm = lora.get(w_norm_name, None) - b_norm = lora.get(b_norm_name, None) - - if w_norm is not None: - loaded_keys.add(w_norm_name) - patch_dict[to_load[x]] = ("diff", (w_norm,)) - if b_norm is not None: - loaded_keys.add(b_norm_name) - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) - - diff_name = "{}.diff".format(x) - diff_weight = lora.get(diff_name, None) - if diff_weight is not None: - patch_dict[to_load[x]] = ("diff", (diff_weight,)) - loaded_keys.add(diff_name) - - diff_bias_name = "{}.diff_b".format(x) - diff_bias = lora.get(diff_bias_name, None) - if diff_bias is not None: - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) - loaded_keys.add(diff_bias_name) - - remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys} - return patch_dict, remaining_dict - -def model_lora_keys_clip(model, key_map={}): - sdk = model.state_dict().keys() - - text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" - clip_l_present = False - for b in range(32): #TODO: clean up - for c in LORA_CLIP_MAP: - k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) - key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base - key_map[lora_key] = k - clip_l_present = True - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) - if k in sdk: - if clip_l_present: - lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base - key_map[lora_key] = k - lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - else: - lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner - key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora - key_map[lora_key] = k - - return key_map - -def model_lora_keys_unet(model, key_map={}): - sdk = model.state_dict().keys() - - for k in sdk: - if k.startswith("diffusion_model.") and k.endswith(".weight"): - key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = k - - diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.diffusion_model.legacy_config) - for k in diffusers_keys: - if k.endswith(".weight"): - unet_key = "diffusion_model.{}".format(diffusers_keys[k]) - key_lora = k[:-len(".weight")].replace(".", "_") - key_map["lora_unet_{}".format(key_lora)] = unet_key - - diffusers_lora_prefix = ["", "unet."] - for p in diffusers_lora_prefix: - diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_")) - if diffusers_lora_key.endswith(".to_out.0"): - diffusers_lora_key = diffusers_lora_key[:-2] - 
key_map[diffusers_lora_key] = unet_key - return key_map +from backend.patcher.lora import * diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py index 02d99296..4b4fb30e 100644 --- a/ldm_patched/modules/model_patcher.py +++ b/ldm_patched/modules/model_patcher.py @@ -1,386 +1 @@ -# 1st edit by https://github.com/comfyanonymous/ComfyUI -# 2nd edit by Forge Official - - -import torch -import copy -import inspect - -import ldm_patched.modules.utils -import ldm_patched.modules.model_management - - -extra_weight_calculators = {} - - -class ModelPatcher: - def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): - self.size = size - self.model = model - self.patches = {} - self.backup = {} - self.object_patches = {} - self.object_patches_backup = {} - self.model_options = {"transformer_options":{}} - self.model_size() - self.load_device = load_device - self.offload_device = offload_device - if current_device is None: - self.current_device = self.offload_device - else: - self.current_device = current_device - - self.weight_inplace_update = weight_inplace_update - - def model_size(self): - if self.size > 0: - return self.size - model_sd = self.model.state_dict() - self.size = ldm_patched.modules.model_management.module_size(self.model) - self.model_keys = set(model_sd.keys()) - return self.size - - def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) - n.patches = {} - for k in self.patches: - n.patches[k] = self.patches[k][:] - - n.object_patches = self.object_patches.copy() - n.model_options = copy.deepcopy(self.model_options) - n.model_keys = self.model_keys - return n - - def is_clone(self, other): - if hasattr(other, 'model') and self.model is other.model: - return True - return False - - def memory_required(self, input_shape): - return self.model.memory_required(input_shape=input_shape) - - def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): - if len(inspect.signature(sampler_cfg_function).parameters) == 3: - self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way - else: - self.model_options["sampler_cfg_function"] = sampler_cfg_function - if disable_cfg1_optimization: - self.model_options["disable_cfg1_optimization"] = True - - def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): - self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] - if disable_cfg1_optimization: - self.model_options["disable_cfg1_optimization"] = True - - def set_model_unet_function_wrapper(self, unet_wrapper_function): - self.model_options["model_function_wrapper"] = unet_wrapper_function - - def set_model_vae_encode_wrapper(self, wrapper_function): - self.model_options["model_vae_encode_wrapper"] = wrapper_function - - def set_model_vae_decode_wrapper(self, wrapper_function): - self.model_options["model_vae_decode_wrapper"] = wrapper_function - - def set_model_vae_regulation(self, vae_regulation): - self.model_options["model_vae_regulation"] = vae_regulation - - def set_model_patch(self, patch, name): - to = self.model_options["transformer_options"] - if "patches" not in to: - to["patches"] = {} - to["patches"][name] = to["patches"].get(name, []) + [patch] - - def 
set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): - to = self.model_options["transformer_options"] - if "patches_replace" not in to: - to["patches_replace"] = {} - if name not in to["patches_replace"]: - to["patches_replace"][name] = {} - if transformer_index is not None: - block = (block_name, number, transformer_index) - else: - block = (block_name, number) - to["patches_replace"][name][block] = patch - - def set_model_attn1_patch(self, patch): - self.set_model_patch(patch, "attn1_patch") - - def set_model_attn2_patch(self, patch): - self.set_model_patch(patch, "attn2_patch") - - def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): - self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) - - def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): - self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) - - def set_model_attn1_output_patch(self, patch): - self.set_model_patch(patch, "attn1_output_patch") - - def set_model_attn2_output_patch(self, patch): - self.set_model_patch(patch, "attn2_output_patch") - - def set_model_input_block_patch(self, patch): - self.set_model_patch(patch, "input_block_patch") - - def set_model_input_block_patch_after_skip(self, patch): - self.set_model_patch(patch, "input_block_patch_after_skip") - - def set_model_output_block_patch(self, patch): - self.set_model_patch(patch, "output_block_patch") - - def add_object_patch(self, name, obj): - self.object_patches[name] = obj - - def model_patches_to(self, device): - to = self.model_options["transformer_options"] - if "patches" in to: - patches = to["patches"] - for name in patches: - patch_list = patches[name] - for i in range(len(patch_list)): - if hasattr(patch_list[i], "to"): - patch_list[i] = patch_list[i].to(device) - if "patches_replace" in to: - patches = to["patches_replace"] - for name in patches: - patch_list = patches[name] - for k in patch_list: - if hasattr(patch_list[k], "to"): - patch_list[k] = patch_list[k].to(device) - if "model_function_wrapper" in self.model_options: - wrap_func = self.model_options["model_function_wrapper"] - if hasattr(wrap_func, "to"): - self.model_options["model_function_wrapper"] = wrap_func.to(device) - - def model_dtype(self): - if hasattr(self.model, "get_dtype"): - return self.model.get_dtype() - - def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): - p = set() - for k in patches: - if k in self.model_keys: - p.add(k) - current_patches = self.patches.get(k, []) - current_patches.append((strength_patch, patches[k], strength_model)) - self.patches[k] = current_patches - - return list(p) - - def get_key_patches(self, filter_prefix=None): - ldm_patched.modules.model_management.unload_model_clones(self) - model_sd = self.model_state_dict() - p = {} - for k in model_sd: - if filter_prefix is not None: - if not k.startswith(filter_prefix): - continue - if k in self.patches: - p[k] = [model_sd[k]] + self.patches[k] - else: - p[k] = (model_sd[k],) - return p - - def model_state_dict(self, filter_prefix=None): - sd = self.model.state_dict() - keys = list(sd.keys()) - if filter_prefix is not None: - for k in keys: - if not k.startswith(filter_prefix): - sd.pop(k) - return sd - - def patch_model(self, device_to=None, patch_weights=True): - for k in self.object_patches: - old = ldm_patched.modules.utils.get_attr(self.model, k) - if k not in self.object_patches_backup: - self.object_patches_backup[k] = old 
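Note: in the removed ModelPatcher above, add_patches only records patches whose keys exist in the model and returns the keys it accepted, while clone() copies the per-key patch lists so clones can diverge without duplicating tensors. A toy sketch of that bookkeeping, with plain floats standing in for tensors and a hypothetical class name:

```python
class TinyPatcher:
    def __init__(self, state_dict):
        self.model_keys = set(state_dict)
        self.patches = {}   # key -> list of (strength_patch, value, strength_model) tuples

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        accepted = set()
        for k, v in patches.items():
            if k in self.model_keys:               # keys the model doesn't have are ignored
                accepted.add(k)
                self.patches.setdefault(k, []).append((strength_patch, v, strength_model))
        return list(accepted)

    def clone(self):
        n = TinyPatcher({})
        n.model_keys = self.model_keys             # shared: the model itself is not copied
        n.patches = {k: v[:] for k, v in self.patches.items()}   # per-key lists are copied
        return n

p = TinyPatcher({"layer.weight": 1.0})
print(p.add_patches({"layer.weight": 0.5, "missing.weight": 0.1}, strength_patch=0.8))
# -> ['layer.weight']  ('missing.weight' is silently dropped)
```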
- ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k]) - - if patch_weights: - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) - continue - - weight = model_sd[key] - - inplace_update = self.weight_inplace_update - - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) - - if device_to is not None: - temp_weight = ldm_patched.modules.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - ldm_patched.modules.utils.copy_to_param(self.model, key, out_weight) - else: - ldm_patched.modules.utils.set_attr(self.model, key, out_weight) - del temp_weight - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to - - return self.model - - def calculate_weight(self, patches, weight, key): - for p in patches: - alpha = p[0] - v = p[1] - strength_model = p[2] - - if strength_model != 1.0: - weight *= strength_model - - if isinstance(v, list): - v = (self.calculate_weight(v[1:], v[0].clone(), key), ) - - if len(v) == 1: - patch_type = "diff" - elif len(v) == 2: - patch_type = v[0] - v = v[1] - - if patch_type == "diff": - w1 = v[0] - if alpha != 0.0: - if w1.shape != weight.shape: - if w1.ndim == weight.ndim == 4: - new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)] - print(f'Merged with {key} channel changed to {new_shape}') - new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) - new_weight = torch.zeros(size=new_shape).to(weight) - new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight - new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff - new_weight = new_weight.contiguous().clone() - weight = new_weight - else: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) - else: - weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) - elif patch_type == "lora": #lora/locon - mat1 = ldm_patched.modules.model_management.cast_to_device(v[0], weight.device, torch.float32) - mat2 = ldm_patched.modules.model_management.cast_to_device(v[1], weight.device, torch.float32) - if v[2] is not None: - alpha *= v[2] / mat2.shape[0] - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = ldm_patched.modules.model_management.cast_to_device(v[3], weight.device, torch.float32) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) - except Exception as e: - print("ERROR", key, e) - elif patch_type == "lokr": - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1_a, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w1_b, weight.device, 
torch.float32)) - else: - w1 = ldm_patched.modules.model_management.cast_to_device(w1, weight.device, torch.float32) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w2_b, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w2_a, weight.device, torch.float32)) - else: - w2 = ldm_patched.modules.model_management.cast_to_device(w2, weight.device, torch.float32) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha *= v[2] / dim - - try: - weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) - except Exception as e: - print("ERROR", key, e) - elif patch_type == "loha": - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha *= v[2] / w1b.shape[0] - w2a = v[3] - w2b = v[4] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - ldm_patched.modules.model_management.cast_to_device(t1, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32)) - - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - ldm_patched.modules.model_management.cast_to_device(t2, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32)) - else: - m1 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w1a, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w1b, weight.device, torch.float32)) - m2 = torch.mm(ldm_patched.modules.model_management.cast_to_device(w2a, weight.device, torch.float32), - ldm_patched.modules.model_management.cast_to_device(w2b, weight.device, torch.float32)) - - try: - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) - except Exception as e: - print("ERROR", key, e) - elif patch_type == "glora": - if v[4] is not None: - alpha *= v[4] / v[0].shape[0] - - a1 = ldm_patched.modules.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) - a2 = ldm_patched.modules.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) - b1 = ldm_patched.modules.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) - b2 = ldm_patched.modules.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) - - weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) - elif patch_type in extra_weight_calculators: - weight = extra_weight_calculators[patch_type](weight, alpha, v) - else: - print("patch type not recognized", patch_type, key) - - return weight - - def unpatch_model(self, device_to=None): - keys = list(self.backup.keys()) - - if self.weight_inplace_update: - for k in keys: - ldm_patched.modules.utils.copy_to_param(self.model, k, self.backup[k]) - else: - for k in keys: 
- ldm_patched.modules.utils.set_attr(self.model, k, self.backup[k]) - - self.backup = {} - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to - - keys = list(self.object_patches_backup.keys()) - for k in keys: - ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k]) - - self.object_patches_backup = {} +from backend.patcher.base import * diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index af62b0a8..dd0f0925 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -25,7 +25,6 @@ class UnetPatcher(ModelPatcher): n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) - n.model_keys = self.model_keys n.controlnet_linked_list = self.controlnet_linked_list n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
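Note: both ldm_patched modules above are reduced to one-line re-exports, so the weight math now lives under backend/patcher. The core of the old calculate_weight's "lora" branch is W' = W + strength * (alpha / rank) * (up @ down), computed in float32 and cast back to the weight's dtype. A minimal sketch of just that branch, assuming a single combined strength factor (the original keeps per-patch and per-model strengths separate and also handles the optional locon mid weight):

```python
import torch

def merge_lora_weight(weight, up, down, alpha=None, strength=1.0):
    """W' = W + strength * (alpha / rank) * (up @ down), reshaped to W's shape."""
    up32 = up.to(torch.float32).flatten(start_dim=1)
    down32 = down.to(torch.float32).flatten(start_dim=1)
    scale = strength
    if alpha is not None:
        scale *= alpha / down32.shape[0]            # rank is the first dim of lora_down
    delta = torch.mm(up32, down32).reshape(weight.shape)
    return (weight.to(torch.float32) + scale * delta).to(weight.dtype)

w = torch.zeros(8, 8)
up, down = torch.randn(8, 4), torch.randn(4, 8)     # rank-4 LoRA factors
print(merge_lora_weight(w, up, down, alpha=4.0, strength=0.8).shape)   # torch.Size([8, 8])
```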