Support wd_on_output for DoRA (#2856)

machina
2025-05-05 08:32:33 -05:00
committed by GitHub
parent 0b26121335
commit 0ced1d0cd0


@@ -49,24 +49,33 @@ def model_lora_keys_unet(model, key_map={}):
 @torch.inference_mode()
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
-    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function):
+    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/80a44b97f5cbcb890896e2b9e65d177f1ac6a588/comfy/weight_adapter/base.py#L42
     dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
     lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
+    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
+    if wd_on_output_axis:
+        weight_norm = (
+            weight.reshape(weight.shape[0], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
+        )
+    else:
+        weight_norm = (
+            weight_calc.transpose(0, 1)
+            .reshape(weight_calc.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
+            .transpose(0, 1)
+        )
     weight_norm = weight_norm + torch.finfo(weight.dtype).eps
     weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
     if strength != 1.0:
         weight_calc -= weight
-        weight += strength * weight_calc
+        weight += strength * (weight_calc)
     else:
         weight[:] = weight_calc
     return weight
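
For reference, a minimal standalone sketch of the new axis selection outside the repository (the tensor shapes below are made up to illustrate; in the real code dora_scale comes from the loaded DoRA weights and weight_calc is the weight plus the scaled LoRA delta):

import torch

def dora_norm(weight, weight_calc, dora_scale):
    # If the DoRA magnitude vector has one entry per output channel, take the norm
    # over each output row of the original weight; otherwise keep the previous
    # behavior and take the norm over each input column of the patched weight.
    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        return (weight.reshape(weight.shape[0], -1)
                .norm(dim=1, keepdim=True)
                .reshape(weight.shape[0], *[1] * (weight.dim() - 1)))
    return (weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1))

w = torch.randn(8, 4)                      # [out_features, in_features]
wc = w + 0.01 * torch.randn(8, 4)          # weight after adding the scaled LoRA delta
print(dora_norm(w, wc, torch.ones(8, 1)).shape)   # torch.Size([8, 1]) -> per-output norm
print(dora_norm(w, wc, torch.ones(1, 4)).shape)   # torch.Size([1, 4]) -> per-input norm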
@@ -163,7 +172,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
                     lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
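
For context on the padding line above: torch.nn.functional.pad with (0, expand_factor) appends zero columns along the last dimension so the LoRA delta matches a larger target weight before merging. A tiny standalone illustration with made-up shapes (expand_factor in the real code is derived from the target weight, not hard-coded):

import torch
import torch.nn.functional as F

lora_diff = torch.randn(320, 4)     # hypothetical delta: [out_features, in_features]
expand_factor = 4                   # assumed number of extra columns in the target weight
lora_diff = F.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
print(lora_diff.shape)              # torch.Size([320, 8]); the appended columns are zeros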
@@ -211,7 +220,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             try:
                 lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
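
The try block above rebuilds a full-size delta as a Kronecker product of two smaller factors (the LoKr-style formulation); a quick shape sanity check with illustrative factor sizes (w1 and w2 here are random stand-ins, not real adapter weights):

import torch

w1 = torch.randn(8, 4)              # small "structure" factor
w2 = torch.randn(40, 80)            # larger block factor
weight_shape = (8 * 40, 4 * 80)     # torch.kron multiplies sizes dimension-wise
lora_diff = torch.kron(w1, w2).reshape(weight_shape)
print(lora_diff.shape)              # torch.Size([320, 320])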
@@ -249,7 +258,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
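
This last branch forms the delta as an element-wise (Hadamard) product of two reconstructions, as in LoHa-style adapters; an illustrative sketch with made-up low-rank factors:

import torch

# Hypothetical rank-4 factors; in LoHa-style adapters m1 and m2 are each
# reconstructed from a pair of low-rank matrices before being multiplied.
m1 = torch.randn(16, 4) @ torch.randn(4, 32)
m2 = torch.randn(16, 4) @ torch.randn(4, 32)
weight = torch.randn(16, 32)
lora_diff = (m1 * m2).reshape(weight.shape)   # element-wise product, same shape as weight
print(lora_diff.shape)                        # torch.Size([16, 32])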