diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index a07ffc5f..1d215c92 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -49,24 +49,33 @@ def model_lora_keys_unet(model, key_map={}): @torch.inference_mode() -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype): - # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33 - +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function): + # Modified from https://github.com/comfyanonymous/ComfyUI/blob/80a44b97f5cbcb890896e2b9e65d177f1ac6a588/comfy/weight_adapter/base.py#L42 dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype) lora_diff *= alpha - weight_calc = weight + lora_diff.type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) + weight_calc = weight + function(lora_diff).type(weight.dtype) + + wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0] + if wd_on_output_axis: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[0], *[1] * (weight.dim() - 1)) + ) + else: + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + weight_norm = weight_norm + torch.finfo(weight.dtype).eps weight_calc *= (dora_scale / weight_norm).type(weight.dtype) if strength != 1.0: weight_calc -= weight - weight += strength * weight_calc + weight += strength * (weight_calc) else: weight[:] = weight_calc return weight @@ -163,7 +172,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) @@ -211,7 +220,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: @@ -249,7 +258,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: