from typing import Callable, Optional

import torch
import torch.nn as nn

import comfy.model_management


class WeightAdapterBase:
    """
    Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)

    Bypass Mode:
        All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))

        - h(x): Additive component (LoRA path). Returns delta to add to base output.
        - g(y): Output transformation. Applied after base + h(x).

        For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
        For OFT/BOFT: g = transform, h = 0
    """

    name: str
    loaded_keys: set[str]
    weights: list[torch.Tensor]

    # Attributes set by bypass system
    multiplier: float = 1.0
    shape: tuple = None  # (out_features, in_features) or (out_ch, in_ch, *kernel)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
    ) -> Optional["WeightAdapterBase"]:
        raise NotImplementedError

    def to_train(self) -> "WeightAdapterTrainBase":
        raise NotImplementedError

    @classmethod
    def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
        """
        weight: The original weight tensor to be modified.
        *args: Additional arguments for configuration, such as rank, alpha etc.
        """
        raise NotImplementedError

    def calculate_shape(
        self,
        key
    ):
        return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        raise NotImplementedError

    # ===== Bypass Mode Methods =====
    #
    # IMPORTANT: Bypass mode is designed for quantized models where original weights
    # may not be accessible in a usable format. Therefore, h() and bypass_forward()
    # do NOT take org_weight as a parameter. All necessary information (out_channels,
    # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component: h(x, base_out)

        Computes the adapter's contribution to be added to base forward output.
        For adapters that only transform output (OFT/BOFT), returns zeros.

        Note:
            This method does NOT access original model weights. Bypass mode is
            designed for quantized models where weights may not be in a usable format.
            All shape info comes from module attributes set by BypassForwardHook.

        Args:
            x: Input tensor
            base_out: Output from base forward f(x), can be used for shape reference

        Returns:
            Delta tensor to add to base output. Shape matches base output.

        Reference: LyCORIS LoConModule.bypass_forward_diff
        """
        # Default: no additive component (for OFT/BOFT)
        # Simply return zeros matching base_out shape
        return torch.zeros_like(base_out)

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation: g(y)

        Applied after base forward + h(x). For most adapters this is identity.
        OFT/BOFT override this to apply orthogonal transformation.

        Args:
            y: Combined output (base + h(x))

        Returns:
            Transformed output

        Reference: LyCORIS OFTModule applies orthogonal transform here
        """
        # Default: identity (for LoRA/LoHa/LoKr)
        return y

    def bypass_forward(
        self,
        org_forward: Callable,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Full bypass forward: g(f(x) + h(x, f(x)))

        Note:
            This method does NOT take org_weight/org_bias parameters. Bypass mode
            is designed for quantized models where weights may not be accessible.
            The original forward function handles weight access internally.

        Args:
            org_forward: Original module forward function
            x: Input tensor
            *args, **kwargs: Additional arguments for org_forward

        Returns:
            Output with adapter applied in bypass mode

        Reference: LyCORIS LoConModule.bypass_forward
        """
        # Base forward: f(x)
        base_out = org_forward(x, *args, **kwargs)

        # Additive component: h(x, base_out) - base_out provided for shape reference
        h_out = self.h(x, base_out)

        # Output transformation: g(base + h)
        return self.g(base_out + h_out)

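# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the loader API): the bypass pattern
# documented above, g(f(x) + h(x)), for a rank-r LoRA-style adapter wrapping a
# plain nn.Linear. The names `_ExampleLoRABypass` and `_example_bypass_usage`
# are hypothetical and exist only to show how h(), g() and bypass_forward()
# compose with a module's original forward.

class _ExampleLoRABypass(WeightAdapterBase):
    def __init__(self, down: torch.Tensor, up: torch.Tensor, alpha: float):
        self.down = down  # (rank, in_features)
        self.up = up      # (out_features, rank)
        self.alpha = alpha
        self.rank = down.shape[0]

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        # Low-rank additive path, scaled by multiplier and alpha / rank
        # (the conventional LoRA scaling).
        scale = self.multiplier * (self.alpha / self.rank)
        return scale * (x @ self.down.t() @ self.up.t())

    # g() is inherited unchanged: identity, as for all LoRA/LoHa/LoKr adapters.


def _example_bypass_usage():
    base = nn.Linear(8, 4, bias=False)
    adapter = _ExampleLoRABypass(torch.randn(2, 8) * 0.01, torch.randn(4, 2) * 0.01, alpha=2.0)
    x = torch.randn(3, 8)
    out = adapter.bypass_forward(base.forward, x)  # g(f(x) + h(x, f(x)))
    assert out.shape == (3, 4)
    return out
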
class WeightAdapterTrainBase(nn.Module):
    """
    Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)

    Bypass Mode:
        All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))

        - h(x): Additive component (LoRA path). Returns delta to add to base output.
        - g(y): Output transformation. Applied after base + h(x).

        For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
        For OFT: g = transform, h = 0

    Note:
        Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
        with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).

    We follow the scheme of PR #7032
    """

    # Attributes set by bypass system (BypassForwardHook)
    # These are set before h()/g()/bypass_forward() are called
    multiplier: float = 1.0
    is_conv: bool = False
    conv_dim: int = 0  # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
    kw_dict: dict = {}  # Conv kwargs: stride, padding, dilation, groups
    kernel_size: tuple = ()
    in_channels: int = None
    out_channels: int = None

    def __init__(self):
        super().__init__()

    def __call__(self, w):
        """
        Weight modification mode: returns modified weight.

        Args:
            w: The original weight tensor to be modified.

        Returns:
            Modified weight tensor.
        """
        raise NotImplementedError

    # ===== Bypass Mode Methods =====

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component: h(x, base_out)

        Computes the adapter's contribution to be added to base forward output.
        For adapters that only transform output (OFT), returns zeros.

        Args:
            x: Input tensor
            base_out: Output from base forward f(x), can be used for shape reference

        Returns:
            Delta tensor to add to base output. Shape matches base output.

        Subclasses should override this method.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__}.h() not implemented. "
            "Subclasses must implement h() for bypass mode."
        )

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation: g(y)

        Applied after base forward + h(x). For most adapters this is identity.
        OFT overrides this to apply orthogonal transformation.

        Args:
            y: Combined output (base + h(x))

        Returns:
            Transformed output
        """
        # Default: identity (for LoRA/LoHa/LoKr)
        return y

    def bypass_forward(
        self,
        org_forward: Callable,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Full bypass forward: g(f(x) + h(x, f(x)))

        Args:
            org_forward: Original module forward function
            x: Input tensor
            *args, **kwargs: Additional arguments for org_forward

        Returns:
            Output with adapter applied in bypass mode
        """
        # Base forward: f(x)
        base_out = org_forward(x, *args, **kwargs)

        # Additive component: h(x, base_out) - base_out provided for shape reference
        h_out = self.h(x, base_out)

        # Output transformation: g(base + h)
        return self.g(base_out + h_out)

    def passive_memory_usage(self):
        raise NotImplementedError("passive_memory_usage is not implemented")

    def move_to(self, device):
        self.to(device)
        return self.passive_memory_usage()

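# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not one of the shipped adapters): for a
# linear LoRA the two modes exposed by WeightAdapterTrainBase agree, since
# x @ (W + up @ down).T == x @ W.T + (x @ down.T) @ up.T.  __call__ implements
# the weight-modification view, h() the bypass view.

class _ExampleLoRATrain(WeightAdapterTrainBase):
    def __init__(self, out_features: int, in_features: int, rank: int):
        super().__init__()
        self.down = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.up = nn.Parameter(torch.randn(out_features, rank) * 0.01)

    def __call__(self, w):
        # Weight modification mode: return the patched weight.
        return w + self.multiplier * (self.up @ self.down)

    def h(self, x, base_out):
        # Bypass mode: the same contribution, computed on activations instead.
        return self.multiplier * (x @ self.down.t() @ self.up.t())

    def passive_memory_usage(self):
        return sum(p.numel() * p.element_size() for p in self.parameters())


def _example_train_modes_agree():
    base = nn.Linear(8, 4, bias=False)
    adapter = _ExampleLoRATrain(4, 8, rank=2)
    x = torch.randn(3, 8)
    patched = torch.nn.functional.linear(x, adapter(base.weight))
    bypassed = adapter.bypass_forward(base.forward, x)
    assert torch.allclose(patched, bypassed, atol=1e-6)
    return patched
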
def weight_decompose(
    dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
):
    dora_scale = comfy.model_management.cast_to_device(
        dora_scale, weight.device, intermediate_dtype
    )
    lora_diff *= alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)

    wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
    if wd_on_output_axis:
        weight_norm = (
            weight.reshape(weight.shape[0], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
        )
    else:
        weight_norm = (
            weight_calc.transpose(0, 1)
            .reshape(weight_calc.shape[1], -1)
            .norm(dim=1, keepdim=True)
            .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
            .transpose(0, 1)
        )
    weight_norm = weight_norm + torch.finfo(weight.dtype).eps

    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    if strength != 1.0:
        weight_calc -= weight
        weight += strength * (weight_calc)
    else:
        weight[:] = weight_calc
    return weight

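# ---------------------------------------------------------------------------
# Illustrative sketch of weight_decompose() (DoRA): the scaled LoRA diff is
# added to the weight, then the result is rescaled row-wise by
# dora_scale / row_norm (the DoRA magnitude correction).  The helper name and
# sizes below are hypothetical and only show the expected tensor shapes for a
# linear layer.

def _example_weight_decompose():
    out_features, in_features, rank = 4, 8, 2
    weight = torch.randn(out_features, in_features)
    lora_diff = torch.randn(out_features, rank) @ torch.randn(rank, in_features) * 0.01
    dora_scale = weight.norm(dim=1, keepdim=True)  # (out_features, 1) magnitude vector
    new_weight = weight_decompose(
        dora_scale, weight.clone(), lora_diff, alpha=1.0, strength=1.0,
        intermediate_dtype=torch.float32, function=lambda w: w,
    )
    assert new_weight.shape == (out_features, in_features)
    return new_weight
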
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (List[int]): The desired shape of the padded tensor.

    Returns:
        torch.Tensor: A new tensor padded with zeros to the specified shape.

    Note:
        A ValueError is raised if the new shape is smaller than the original
        tensor in any dimension or has a different number of dimensions.
    """
    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
        raise ValueError(
            "The new shape must be larger than the original tensor in all dimensions"
        )

    if len(new_shape) != len(tensor.shape):
        raise ValueError(
            "The new shape must have the same number of dimensions as the original tensor"
        )

    # Create a new tensor filled with zeros
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)

    # Create slicing tuples for both tensors
    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
    new_slices = tuple(slice(0, dim) for dim in tensor.shape)

    # Copy the original tensor into the new tensor
    padded_tensor[new_slices] = tensor[orig_slices]

    return padded_tensor

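# ---------------------------------------------------------------------------
# Illustrative sketch of pad_tensor_to_shape(): zero-pad a weight to a larger
# geometry, e.g. when a lora expects a wider input dimension than the base
# weight provides.  The sizes and helper name below are arbitrary.

def _example_pad_weight():
    weight = torch.randn(4, 3)                    # (out_features, in_features)
    padded = pad_tensor_to_shape(weight, [4, 8])  # widen in_features with zeros
    assert padded.shape == (4, 8)
    assert torch.equal(padded[:, :3], weight)     # original values land in the top-left block
    assert torch.count_nonzero(padded[:, 3:]) == 0
    return padded
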
def tucker_weight_from_conv(up, down, mid):
    up = up.reshape(up.size(0), up.size(1))
    down = down.reshape(down.size(0), down.size(1))
    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)


def tucker_weight(wa, wb, t):
    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
    return torch.einsum("i j ..., i r -> r j ...", temp, wa)

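# ---------------------------------------------------------------------------
# Illustrative shape check for tucker_weight_from_conv(): `up` and `down` come
# from 1x1 convs, `mid` carries the spatial kernel, and the einsum rebuilds a
# full (out_ch, in_ch, kh, kw) conv weight.  Values and the helper name are
# arbitrary.

def _example_tucker_shapes():
    out_ch, in_ch, rank, kh, kw = 16, 8, 4, 3, 3
    up = torch.randn(out_ch, rank, 1, 1)   # 1x1 "up" conv weight
    down = torch.randn(rank, in_ch, 1, 1)  # 1x1 "down" conv weight
    mid = torch.randn(rank, rank, kh, kw)  # Tucker core holding the spatial kernel
    full = tucker_weight_from_conv(up, down, mid)
    assert full.shape == (out_ch, in_ch, kh, kw)
    return full
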
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """
    Return a tuple of two values whose product is `dimension`, decomposed around
    the number closest to `factor`; the second value is greater than or equal to
    the first.

    examples)
    factor
            -1              2               4               8               16 ...
    127  -> 1, 127     127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
    128  -> 8, 16      128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
    250  -> 10, 25     250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
    360  -> 8, 45      360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
    512  -> 16, 32     512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
    1024 -> 32, 32     1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
    """

    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor < 0:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n

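# ---------------------------------------------------------------------------
# Illustrative sketch of how factorization() is typically consumed by
# LoKr-style adapters (a sketch, not the shipped implementation): each weight
# dimension is split into two Kronecker factors, and kron(w1, w2) recovers the
# full weight shape.  Dimensions and the helper name are arbitrary.

def _example_kron_factors():
    out_dim, in_dim = 512, 320
    out_l, out_k = factorization(out_dim)  # (16, 32) for these dimensions
    in_m, in_n = factorization(in_dim)     # (16, 20)
    w1 = torch.randn(out_l, in_m)
    w2 = torch.randn(out_k, in_n)
    full = torch.kron(w1, w2)              # (out_l * out_k, in_m * in_n)
    assert full.shape == (out_dim, in_dim)
    return full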