diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index 57762424..a07ffc5f 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -129,10 +129,15 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
                     print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                 else:
                     weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
+
+        elif patch_type == "set":
+            weight.copy_(v[0])
+
         elif patch_type == "lora":
             mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
             mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
             dora_scale = v[4]
+
             if v[2] is not None:
                 alpha = v[2] / mat2.shape[0]
             else:
@@ -142,12 +147,26 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
                 mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                 final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                 mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+
             try:
-                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
+                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
+
+                try:
+                    lora_diff = lora_diff.reshape(weight.shape)
+                except:
+                    if weight.shape[1] < lora_diff.shape[1]:
+                        expand_factor = (lora_diff.shape[1] - weight.shape[1])
+                        weight = torch.nn.functional.pad(weight, (0, expand_factor), mode='constant', value=0)
+                    elif weight.shape[1] > lora_diff.shape[1]:
+                        # expand factor should be 1*64 (for FluxTools Canny or Depth), or 5*64 (for FluxTools Fill)
+                        expand_factor = (weight.shape[1] - lora_diff.shape[1])
+                        lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
+
                 if dora_scale is not None:
                     weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
                 raise e
@@ -236,23 +255,45 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             except Exception as e:
                 print("ERROR {} {} {}".format(patch_type, key, e))
                 raise e
-        elif patch_type == "glora":
-            if v[4] is not None:
-                alpha = v[4] / v[0].shape[0]
-            else:
-                alpha = 1.0
 
+        elif patch_type == "glora":
             dora_scale = v[5]
+
+            old_glora = False
+            if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
+                old_glora = True
+
+            if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
+                if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
+                    pass
+                else:
+                    old_glora = False
 
             a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
             a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
             b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
             b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
 
+            if v[4] is None:
+                alpha = 1.0
+            else:
+                if old_glora:
+                    alpha = v[4] / v[0].shape[0]
+                else:
+                    alpha = v[4] / v[1].shape[0]
+
             try:
-                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                if old_glora:
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=computation_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+                else:
+                    if weight.dim() > 2:
+                        lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
+                    else:
+                        lora_diff = torch.mm(torch.mm(weight.to(dtype=computation_dtype), a1), a2).reshape(weight.shape)
+                    lora_diff += torch.mm(b1, b2).reshape(weight.shape)
+
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -299,7 +340,7 @@ class LoraLoader:
         self.loaded_hash = str([])
 
     @torch.inference_mode()
-    def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh = False):
+    def refresh(self, lora_patches, offload_device=torch.device('cpu'), force_refresh=False):
         hashes = str(list(lora_patches.keys()))
 
         if hashes == self.loaded_hash and not force_refresh:
diff --git a/packages_3rdparty/comfyui_lora_collection/lora.py b/packages_3rdparty/comfyui_lora_collection/lora.py
index 59e3e791..89beaf0e 100644
--- a/packages_3rdparty/comfyui_lora_collection/lora.py
+++ b/packages_3rdparty/comfyui_lora_collection/lora.py
@@ -30,6 +30,19 @@ LORA_CLIP_MAP = {
 
 
 def load_lora(lora, to_load):
+    # BFL loras for Flux; from ComfyUI: comfy/lora_convert.py
+    def convert_lora_bfl_control(sd):
+        import torch
+        sd_out = {}
+        for k in sd:
+            k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+            sd_out[k_to] = sd[k]
+
+        return sd_out
+
+    if "img_in.lora_A.weight" in lora and "single_blocks.0.norm.key_norm.scale" in lora:
+        lora = convert_lora_bfl_control(lora)
+
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -189,6 +202,12 @@ def load_lora(lora, to_load):
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
             loaded_keys.add(diff_bias_name)
 
+        set_weight_name = "{}.set_weight".format(x)
+        set_weight = lora.get(set_weight_name, None)
+        if set_weight is not None:
+            patch_dict[to_load[x]] = ("set", (set_weight,))
+            loaded_keys.add(set_weight_name)
+
     remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
     return patch_dict, remaining_dict
 
@@ -269,11 +288,13 @@ def model_lora_keys_unet(model, key_map={}):
     sdk = sd.keys()
 
     for k in sdk:
-        if k.startswith("diffusion_model.") and k.endswith(".weight"):
-            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
-            key_map["lora_unet_{}".format(key_lora)] = k
-            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
-            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+        if k.startswith("diffusion_model."):
+            if k.endswith(".weight"):
+                key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+                key_map["lora_unet_{}".format(key_lora)] = k
+                key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+            else:
+                key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
 
     diffusers_keys = utils.unet_to_diffusers(model.diffusion_model.config)
     for k in diffusers_keys:
@@ -281,7 +302,8 @@ def model_lora_keys_unet(model, key_map={}):
             unet_key = "diffusion_model.{}".format(diffusers_keys[k])
             key_lora = k[:-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = unet_key
-
+            key_map["lycoris_{}".format(key_lora)] = unet_key #simpletuner lycoris format
+
             diffusers_lora_prefix = ["", "unet."]
             for p in diffusers_lora_prefix:
                 diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
@@ -289,19 +311,19 @@ def model_lora_keys_unet(model, key_map={}):
                     diffusers_lora_key = diffusers_lora_key[:-2]
                 key_map[diffusers_lora_key] = unet_key
 
-    # if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
-    #     diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
-    #     for k in diffusers_keys:
-    #         if k.endswith(".weight"):
-    #             to = diffusers_keys[k]
-    #             key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
-    #             key_map[key_lora] = to
-    #
-    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
-    #             key_map[key_lora] = to
-    #
-    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
-    #             key_map[key_lora] = to
+    # if 'stable-diffusion-3' in model.config.huggingface_repo.lower(): #Diffusers lora SD3
+    #     diffusers_keys = utils.mmdit_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
+    #     for k in diffusers_keys:
+    #         if k.endswith(".weight"):
+    #             to = diffusers_keys[k]
+    #             key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
+    #             key_map[key_lora] = to
+
+    #             key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
+    #             key_map[key_lora] = to
+
+    #             key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
+    #             key_map[key_lora] = to
     #
     # if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
     #     diffusers_keys = utils.auraflow_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")