import logging
from typing import Callable, Optional

import torch
import torch.nn.functional as F

import comfy.model_management
from .base import WeightAdapterBase, weight_decompose
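

# GLoRA (generalized LoRA) parameterizes an update to a base layer W as
# f(x) = Wx + W*A(x) + B(x): the A path (a1/a2) perturbs the input that W
# sees, while the B path (b1/b2) is a purely additive low-rank branch
# (see _compute_paths below).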
class GLoRAAdapter(WeightAdapterBase):
    name = "glora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["GLoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            weights = (
                lora[a1_name],
                lora[a2_name],
                lora[b1_name],
                lora[b2_name],
                alpha,
                dora_scale,
            )
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)
            return cls(loaded_keys, weights)
        else:
            return None
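
    # Illustrative usage sketch (the key name below is hypothetical): for a
    # layer keyed "x", load() expects "{x}.a1.weight", "{x}.a2.weight",
    # "{x}.b1.weight" and "{x}.b2.weight" in the checkpoint, e.g.
    #   adapter = GLoRAAdapter.load("lora_unet_proj_in", sd, 1.0, None)
    # which returns an adapter instance, or None when the a1 key is absent.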

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        dora_scale = v[5]

        old_glora = False
        if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
            rank = v[0].shape[0]
            old_glora = True

        if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
            if (
                old_glora
                and v[1].shape[0] == weight.shape[0]
                and weight.shape[0] == weight.shape[1]
            ):
                pass
            else:
                old_glora = False
                rank = v[1].shape[0]
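
        # Shape sketch (illustrative): for a square Linear with in = out = 640
        # and rank 4, the old LyCORIS layout stores a1:(4,640) a2:(640,4)
        # b1:(4,640) b2:(640,4), while the current layout stores the transposed
        # shapes a1:(640,4) a2:(4,640) b1:(640,4) b2:(4,640). For square layers
        # both checks above can pass, so the weight shape disambiguates.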

        a1 = comfy.model_management.cast_to_device(
            v[0].flatten(start_dim=1), weight.device, intermediate_dtype
        )
        a2 = comfy.model_management.cast_to_device(
            v[1].flatten(start_dim=1), weight.device, intermediate_dtype
        )
        b1 = comfy.model_management.cast_to_device(
            v[2].flatten(start_dim=1), weight.device, intermediate_dtype
        )
        b2 = comfy.model_management.cast_to_device(
            v[3].flatten(start_dim=1), weight.device, intermediate_dtype
        )

        if v[4] is not None:
            alpha = v[4] / rank
        else:
            alpha = 1.0

        try:
            if old_glora:
                lora_diff = (
                    torch.mm(b2, b1)
                    + torch.mm(
                        torch.mm(
                            weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
                        ),
                        a1,
                    )
                ).reshape(
                    weight.shape
                )  # old lycoris glora
            else:
                if weight.dim() > 2:
                    lora_diff = torch.einsum(
                        "o i ..., i j -> o j ...",
                        torch.einsum(
                            "o i ..., i j -> o j ...",
                            weight.to(dtype=intermediate_dtype),
                            a1,
                        ),
                        a2,
                    ).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(
                        torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
                    ).reshape(weight.shape)
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
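
    # Worked form of the merge above in the non-DoRA case (2D weight):
    #   current layout: W' = W + strength * (alpha / rank) * (b1 @ b2 + W @ a1 @ a2)
    #   old layout:     W' = W + strength * (alpha / rank) * (b2 @ b1 + W @ a2 @ a1)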

    def _compute_paths(self, x: torch.Tensor):
        """
        Compute A path and B path outputs for GLoRA bypass.

        GLoRA: f(x) = Wx + WAx + Bx
        - A path: a1(a2(x)) - modifies the input to the base forward
        - B path: b1(b2(x)) - additive component

        Note:
            Does not access original model weights - bypass mode is designed
            for quantized models where weights may not be accessible.

        Returns: (a_out, b_out)
        """
        v = self.weights
        # v = (a1, a2, b1, b2, alpha, dora_scale)
        a1 = v[0]
        a2 = v[1]
        b1 = v[2]
        b2 = v[3]
        alpha = v[4]

        dtype = x.dtype

        # Cast dtype (weights should already be on the correct device from inject())
        a1 = a1.to(dtype=dtype)
        a2 = a2.to(dtype=dtype)
        b1 = b1.to(dtype=dtype)
        b2 = b2.to(dtype=dtype)

        # Determine rank and scale by checking for the old vs. new glora layout
        old_glora = False
        if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
            rank = a1.shape[0]
            old_glora = True

        if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
            # Without access to the base weight, resolve the square-layer
            # ambiguity by checking whether a2 still maps from the input
            # features (as the old layout does).
            if old_glora and a2.shape[0] == x.shape[-1]:
                pass
            else:
                old_glora = False
                rank = a2.shape[0]

        if alpha is not None:
            scale = alpha / rank
        else:
            scale = 1.0

        # Apply multiplier
        multiplier = getattr(self, "multiplier", 1.0)
        scale = scale * multiplier

        # Use module info from bypass injection, not the input tensor's shape
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        if is_conv:
            # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
            conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]

            # Get the module's stride/padding for spatial dimension handling
            module_stride = kw_dict.get("stride", (1,) * conv_dim)
            module_padding = kw_dict.get("padding", (0,) * conv_dim)
            kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
            in_channels = getattr(self, "in_channels", None)

            # Ensure weights are in conv shape
            # a1, a2, b1 are always 1x1 kernels
            if a1.ndim == 2:
                a1 = a1.view(*a1.shape, *([1] * conv_dim))
            if a2.ndim == 2:
                a2 = a2.view(*a2.shape, *([1] * conv_dim))
            if b1.ndim == 2:
                b1 = b1.view(*b1.shape, *([1] * conv_dim))
            # b2 has the actual kernel_size (like a LoRA down weight)
            if b2.ndim == 2:
                if in_channels is not None:
                    b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
                else:
                    b2 = b2.view(*b2.shape, *([1] * conv_dim))

            # A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed;
            # a_out is added to x
            a2_out = conv_fn(x, a2)
            a_out = conv_fn(a2_out, a1) * scale

            # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
            b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
            b_out = conv_fn(b2_out, b1) * scale
        else:
            # Linear case
            if old_glora:
                # Old layout (diff = b2 @ b1 + W @ a2 @ a1):
                # A path is x @ a1.T @ a2.T, B path is x @ b1.T @ b2.T
                a_out = F.linear(F.linear(x, a1), a2) * scale
                b_out = F.linear(F.linear(x, b1), b2) * scale
            else:
                # Current layout (diff = b1 @ b2 + W @ a1 @ a2):
                # A path is a1(a2(x)), B path is b1(b2(x))
                a_out = F.linear(F.linear(x, a2), a1) * scale
                b_out = F.linear(F.linear(x, b2), b1) * scale

        return a_out, b_out
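
    # Consistency sketch (illustrative; current layout, 2D weight): the paths
    # above reproduce the merged result of calculate_weight, because
    #   F.linear(x, W + s * (b1 @ b2 + W @ a1 @ a2))
    #     == F.linear(x + s * A(x), W) + s * B(x)
    # with A(x) = a1(a2(x)) and B(x) = b1(b2(x)).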

    def bypass_forward(
        self,
        org_forward: Callable,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        GLoRA bypass forward: f(x + a(x)) + b(x)

        Unlike standard adapters, GLoRA modifies the input to the base forward
        AND adds the B path output.

        Note:
            Does not access original model weights - bypass mode is designed
            for quantized models where weights may not be accessible.

        Reference: LyCORIS GLoRAModule._bypass_forward
        """
        a_out, b_out = self._compute_paths(x)

        # Call the base forward with the modified input
        base_out = org_forward(x + a_out, *args, **kwargs)

        # Add the B path
        return base_out + b_out
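
    # Sketch of the intended call pattern (the injection wrapper itself is
    # assumed, not defined in this file):
    #   out = adapter.bypass_forward(module.forward, x)
    # which evaluates module.forward(x + A(x)) + B(x).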

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        For GLoRA, h() returns the B path output.

        Note:
            GLoRA's full bypass requires overriding bypass_forward(), since it
            also modifies the input to org_forward. This h() is provided for
            API compatibility, but bypass_forward() should be used for correct
            behavior.

            Does not access original model weights - bypass mode is designed
            for quantized models where weights may not be accessible.

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, for API consistency)
        """
        _, b_out = self._compute_paths(x)
        return b_out
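

if __name__ == "__main__":
    # Minimal self-check sketch (illustrative; current layout, Linear only).
    # Assumes this file lives inside a package, so run it as a module from the
    # repo root, e.g. `python -m comfy.weight_adapter.glora`.
    torch.manual_seed(0)
    in_f, out_f, rank = 8, 8, 2
    W = torch.randn(out_f, in_f)
    a1, a2 = torch.randn(in_f, rank) * 0.1, torch.randn(rank, in_f) * 0.1
    b1, b2 = torch.randn(out_f, rank) * 0.1, torch.randn(rank, in_f) * 0.1
    adapter = GLoRAAdapter(set(), (a1, a2, b1, b2, None, None))
    x = torch.randn(3, in_f)
    # Bypass path: f(x + A(x)) + B(x) ...
    bypassed = adapter.bypass_forward(lambda t: F.linear(t, W), x)
    # ... should match merging the same update into the base weight.
    merged = F.linear(x, W + (b1 @ b2 + W @ a1 @ a2))
    assert torch.allclose(bypassed, merged, atol=1e-5)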