mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 19:09:45 +00:00)
Support wd_on_output for DoRA (#2856)
@@ -49,24 +49,33 @@ def model_lora_keys_unet(model, key_map={}):
 
 
 @torch.inference_mode()
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
-    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function):
+    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/80a44b97f5cbcb890896e2b9e65d177f1ac6a588/comfy/weight_adapter/base.py#L42
     dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
     lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
+
+    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
+    if wd_on_output_axis:
+        weight_norm = (
+            weight.reshape(weight.shape[0], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
+        )
+    else:
+        weight_norm = (
+            weight_calc.transpose(0, 1)
+            .reshape(weight_calc.shape[1], -1)
+            .norm(dim=1, keepdim=True)
+            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
+            .transpose(0, 1)
+        )
     weight_norm = weight_norm + torch.finfo(weight.dtype).eps
 
     weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
     if strength != 1.0:
         weight_calc -= weight
-        weight += strength * weight_calc
+        weight += strength * (weight_calc)
     else:
         weight[:] = weight_calc
     return weight
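Note on the new branch above: whether the weight decomposition runs on the output axis is inferred purely from the shape of dora_scale. The standalone sketch below (toy tensors, not Forge code; the shapes and scale values are made up for illustration) shows the detection and the shape of the norm tensor each branch produces for a 2-D linear weight.

import torch

out_features, in_features = 4, 3
weight = torch.randn(out_features, in_features)
lora_diff = 0.1 * torch.randn(out_features, in_features)
weight_calc = weight + lora_diff

# Hypothetical scales: one stored per output row, one per input column.
dora_scale_out = torch.ones(out_features, 1)  # shape[0] == weight.shape[0] -> output axis
dora_scale_in = torch.ones(1, in_features)    # shape[0] == 1 -> input axis (previous behaviour)

for dora_scale in (dora_scale_out, dora_scale_in):
    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        # Norm over each output row of the base weight, shape [out, 1].
        weight_norm = (
            weight.reshape(weight.shape[0], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
        )
    else:
        # Norm over each input column of the patched weight, shape [1, in].
        weight_norm = (
            weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1)
        )
    print(wd_on_output_axis, tuple(weight_norm.shape))  # (True, (4, 1)) then (False, (1, 3))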
@@ -163,7 +172,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
                 lora_diff = torch.nn.functional.pad(lora_diff, (0, expand_factor), mode='constant', value=0)
 
             if dora_scale is not None:
-                weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
             else:
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
 
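The call-site change above also moves where the per-patch function hook takes effect: previously it wrapped the whole decomposed weight, now it is forwarded into weight_decompose and applied only to lora_diff before the norm is taken, so the base weight no longer passes through the hook. A simplified, self-contained illustration of why the placement matters (decompose_simplified is a stand-in, not the real weight_decompose, and the halving hook is hypothetical):

import torch

def decompose_simplified(weight, lora_diff, function):
    # Stand-in for weight_decompose: only the delta passes through the hook.
    return weight + function(lora_diff)

weight = torch.randn(4, 3)
lora_diff = torch.randn(4, 3)
halve = lambda a: 0.5 * a  # hypothetical non-identity hook

old_style = halve(weight + lora_diff)                       # hook wrapped the whole decomposed result
new_style = decompose_simplified(weight, lora_diff, halve)  # hook applies to the delta only
print(torch.allclose(old_style, new_style))  # False: the base weight is no longer scaled by the hook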
@@ -211,7 +220,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             try:
                 lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -249,7 +258,7 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
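The two hunks above repeat the same signature change for the remaining LoRA-variant branches. As a closing sanity check on what the input-axis branch of weight_decompose computes, the standalone snippet below (toy tensors, not Forge code) verifies that after the DoRA rescale with strength 1 the per-column norms of the merged weight equal dora_scale.

import torch

torch.manual_seed(0)
weight = torch.randn(4, 3)
lora_diff = 0.1 * torch.randn(4, 3)
dora_scale = torch.rand(1, 3) + 0.5  # per-input-column target norms (made-up values)

weight_calc = weight + lora_diff
# Input-axis branch: per-column norms of the patched weight, shape [1, 3].
weight_norm = (
    weight_calc.transpose(0, 1)
    .reshape(weight_calc.shape[1], -1)
    .norm(dim=1, keepdim=True)
    .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
    .transpose(0, 1)
)
weight_norm = weight_norm + torch.finfo(weight.dtype).eps

merged = weight_calc * (dora_scale / weight_norm)
print(torch.allclose(merged.norm(dim=0, keepdim=True), dora_scale, atol=1e-5))  # True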