mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 11:09:50 +00:00
[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer
This commit is contained in:
100
comfy/sd.py
100
comfy/sd.py
@@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline
|
|||||||
import comfy.ldm.hunyuan_video.vae
|
import comfy.ldm.hunyuan_video.vae
|
||||||
import comfy.ldm.mmaudio.vae.autoencoder
|
import comfy.ldm.mmaudio.vae.autoencoder
|
||||||
import comfy.pixel_space_convert
|
import comfy.pixel_space_convert
|
||||||
|
import comfy.weight_adapter
|
||||||
import yaml
|
import yaml
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -101,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||||||
return (new_modelpatcher, new_clip)
|
return (new_modelpatcher, new_clip)
|
||||||
|
|
||||||
|
|
||||||
|
def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||||
|
"""
|
||||||
|
Load LoRA in bypass mode without modifying base model weights.
|
||||||
|
|
||||||
|
Instead of patching weights, this injects the LoRA computation into the
|
||||||
|
forward pass: output = base_forward(x) + lora_path(x)
|
||||||
|
|
||||||
|
Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches.
|
||||||
|
|
||||||
|
This is useful for training and when model weights are offloaded.
|
||||||
|
"""
|
||||||
|
key_map = {}
|
||||||
|
if model is not None:
|
||||||
|
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||||
|
if clip is not None:
|
||||||
|
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||||
|
|
||||||
|
logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries")
|
||||||
|
|
||||||
|
lora = comfy.lora_convert.convert_lora(lora)
|
||||||
|
loaded = comfy.lora.load_lora(lora, key_map)
|
||||||
|
|
||||||
|
logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries")
|
||||||
|
|
||||||
|
# Separate adapters (for bypass) from other patches (for regular patching)
|
||||||
|
bypass_patches = {} # WeightAdapterBase instances -> bypass mode
|
||||||
|
regular_patches = {} # diff, set, bias patches -> regular weight patching
|
||||||
|
|
||||||
|
for key, patch_data in loaded.items():
|
||||||
|
if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase):
|
||||||
|
bypass_patches[key] = patch_data
|
||||||
|
else:
|
||||||
|
regular_patches[key] = patch_data
|
||||||
|
|
||||||
|
logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches")
|
||||||
|
|
||||||
|
k = set()
|
||||||
|
k1 = set()
|
||||||
|
|
||||||
|
if model is not None:
|
||||||
|
new_modelpatcher = model.clone()
|
||||||
|
|
||||||
|
# Apply regular patches (bias diff, weight diff, etc.) via normal patching
|
||||||
|
if regular_patches:
|
||||||
|
patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model)
|
||||||
|
k.update(patched_keys)
|
||||||
|
|
||||||
|
# Apply adapter patches via bypass injection
|
||||||
|
manager = comfy.weight_adapter.BypassInjectionManager()
|
||||||
|
model_sd_keys = set(new_modelpatcher.model.state_dict().keys())
|
||||||
|
|
||||||
|
for key, adapter in bypass_patches.items():
|
||||||
|
if key in model_sd_keys:
|
||||||
|
manager.add_adapter(key, adapter, strength=strength_model)
|
||||||
|
k.add(key)
|
||||||
|
else:
|
||||||
|
logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}")
|
||||||
|
|
||||||
|
injections = manager.create_injections(new_modelpatcher.model)
|
||||||
|
|
||||||
|
if manager.get_hook_count() > 0:
|
||||||
|
new_modelpatcher.set_injections("bypass_lora", injections)
|
||||||
|
else:
|
||||||
|
new_modelpatcher = None
|
||||||
|
|
||||||
|
if clip is not None:
|
||||||
|
new_clip = clip.clone()
|
||||||
|
|
||||||
|
# Apply regular patches to clip
|
||||||
|
if regular_patches:
|
||||||
|
patched_keys = new_clip.add_patches(regular_patches, strength_clip)
|
||||||
|
k1.update(patched_keys)
|
||||||
|
|
||||||
|
# Apply adapter patches via bypass injection
|
||||||
|
clip_manager = comfy.weight_adapter.BypassInjectionManager()
|
||||||
|
clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys())
|
||||||
|
|
||||||
|
for key, adapter in bypass_patches.items():
|
||||||
|
if key in clip_sd_keys:
|
||||||
|
clip_manager.add_adapter(key, adapter, strength=strength_clip)
|
||||||
|
k1.add(key)
|
||||||
|
|
||||||
|
clip_injections = clip_manager.create_injections(new_clip.cond_stage_model)
|
||||||
|
if clip_manager.get_hook_count() > 0:
|
||||||
|
new_clip.patcher.set_injections("bypass_lora", clip_injections)
|
||||||
|
else:
|
||||||
|
new_clip = None
|
||||||
|
|
||||||
|
for x in loaded:
|
||||||
|
if (x not in k) and (x not in k1):
|
||||||
|
patch_data = loaded[x]
|
||||||
|
patch_type = type(patch_data).__name__
|
||||||
|
if isinstance(patch_data, tuple):
|
||||||
|
patch_type = f"tuple({patch_data[0]})"
|
||||||
|
logging.warning(f"NOT LOADED: {x} (type={patch_type})")
|
||||||
|
|
||||||
|
return (new_modelpatcher, new_clip)
|
||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||||
if no_init:
|
if no_init:
|
||||||
|
|||||||
@@ -5,6 +5,11 @@ from .lokr import LoKrAdapter
|
|||||||
from .glora import GLoRAAdapter
|
from .glora import GLoRAAdapter
|
||||||
from .oft import OFTAdapter
|
from .oft import OFTAdapter
|
||||||
from .boft import BOFTAdapter
|
from .boft import BOFTAdapter
|
||||||
|
from .bypass import (
|
||||||
|
BypassInjectionManager,
|
||||||
|
BypassForwardHook,
|
||||||
|
create_bypass_injections_from_patches,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
adapters: list[type[WeightAdapterBase]] = [
|
adapters: list[type[WeightAdapterBase]] = [
|
||||||
@@ -31,4 +36,7 @@ __all__ = [
|
|||||||
"WeightAdapterTrainBase",
|
"WeightAdapterTrainBase",
|
||||||
"adapters",
|
"adapters",
|
||||||
"adapter_maps",
|
"adapter_maps",
|
||||||
|
"BypassInjectionManager",
|
||||||
|
"BypassForwardHook",
|
||||||
|
"create_bypass_injections_from_patches",
|
||||||
] + [a.__name__ for a in adapters]
|
] + [a.__name__ for a in adapters]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -7,12 +7,35 @@ import comfy.model_management
|
|||||||
|
|
||||||
|
|
||||||
class WeightAdapterBase:
|
class WeightAdapterBase:
|
||||||
|
"""
|
||||||
|
Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
|
||||||
|
|
||||||
|
Bypass Mode:
|
||||||
|
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
|
||||||
|
|
||||||
|
- h(x): Additive component (LoRA path). Returns delta to add to base output.
|
||||||
|
- g(y): Output transformation. Applied after base + h(x).
|
||||||
|
|
||||||
|
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
|
||||||
|
For OFT/BOFT: g = transform, h = 0
|
||||||
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
loaded_keys: set[str]
|
loaded_keys: set[str]
|
||||||
weights: list[torch.Tensor]
|
weights: list[torch.Tensor]
|
||||||
|
|
||||||
|
# Attributes set by bypass system
|
||||||
|
multiplier: float = 1.0
|
||||||
|
shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
|
def load(
|
||||||
|
cls,
|
||||||
|
x: str,
|
||||||
|
lora: dict[str, torch.Tensor],
|
||||||
|
alpha: float,
|
||||||
|
dora_scale: torch.Tensor,
|
||||||
|
) -> Optional["WeightAdapterBase"]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def to_train(self) -> "WeightAdapterTrainBase":
|
def to_train(self) -> "WeightAdapterTrainBase":
|
||||||
@@ -39,18 +62,202 @@ class WeightAdapterBase:
|
|||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# ===== Bypass Mode Methods =====
|
||||||
|
#
|
||||||
|
# IMPORTANT: Bypass mode is designed for quantized models where original weights
|
||||||
|
# may not be accessible in a usable format. Therefore, h() and bypass_forward()
|
||||||
|
# do NOT take org_weight as a parameter. All necessary information (out_channels,
|
||||||
|
# in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component: h(x, base_out)
|
||||||
|
|
||||||
|
Computes the adapter's contribution to be added to base forward output.
|
||||||
|
For adapters that only transform output (OFT/BOFT), returns zeros.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method does NOT access original model weights. Bypass mode is
|
||||||
|
designed for quantized models where weights may not be in a usable format.
|
||||||
|
All shape info comes from module attributes set by BypassForwardHook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward f(x), can be used for shape reference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delta tensor to add to base output. Shape matches base output.
|
||||||
|
|
||||||
|
Reference: LyCORIS LoConModule.bypass_forward_diff
|
||||||
|
"""
|
||||||
|
# Default: no additive component (for OFT/BOFT)
|
||||||
|
# Simply return zeros matching base_out shape
|
||||||
|
return torch.zeros_like(base_out)
|
||||||
|
|
||||||
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Output transformation: g(y)
|
||||||
|
|
||||||
|
Applied after base forward + h(x). For most adapters this is identity.
|
||||||
|
OFT/BOFT override this to apply orthogonal transformation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y: Combined output (base + h(x))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed output
|
||||||
|
|
||||||
|
Reference: LyCORIS OFTModule applies orthogonal transform here
|
||||||
|
"""
|
||||||
|
# Default: identity (for LoRA/LoHa/LoKr)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def bypass_forward(
|
||||||
|
self,
|
||||||
|
org_forward: Callable,
|
||||||
|
x: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Full bypass forward: g(f(x) + h(x, f(x)))
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method does NOT take org_weight/org_bias parameters. Bypass mode
|
||||||
|
is designed for quantized models where weights may not be accessible.
|
||||||
|
The original forward function handles weight access internally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_forward: Original module forward function
|
||||||
|
x: Input tensor
|
||||||
|
*args, **kwargs: Additional arguments for org_forward
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output with adapter applied in bypass mode
|
||||||
|
|
||||||
|
Reference: LyCORIS LoConModule.bypass_forward
|
||||||
|
"""
|
||||||
|
# Base forward: f(x)
|
||||||
|
base_out = org_forward(x, *args, **kwargs)
|
||||||
|
|
||||||
|
# Additive component: h(x, base_out) - base_out provided for shape reference
|
||||||
|
h_out = self.h(x, base_out)
|
||||||
|
|
||||||
|
# Output transformation: g(base + h)
|
||||||
|
return self.g(base_out + h_out)
|
||||||
|
|
||||||
|
|
||||||
class WeightAdapterTrainBase(nn.Module):
|
class WeightAdapterTrainBase(nn.Module):
|
||||||
# We follow the scheme of PR #7032
|
"""
|
||||||
|
Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
|
||||||
|
|
||||||
|
Bypass Mode:
|
||||||
|
All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
|
||||||
|
|
||||||
|
- h(x): Additive component (LoRA path). Returns delta to add to base output.
|
||||||
|
- g(y): Output transformation. Applied after base + h(x).
|
||||||
|
|
||||||
|
For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
|
||||||
|
For OFT: g = transform, h = 0
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
|
||||||
|
with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).
|
||||||
|
|
||||||
|
We follow the scheme of PR #7032
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Attributes set by bypass system (BypassForwardHook)
|
||||||
|
# These are set before h()/g()/bypass_forward() are called
|
||||||
|
multiplier: float = 1.0
|
||||||
|
is_conv: bool = False
|
||||||
|
conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
|
||||||
|
kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups
|
||||||
|
kernel_size: tuple = ()
|
||||||
|
in_channels: int = None
|
||||||
|
out_channels: int = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def __call__(self, w):
|
def __call__(self, w):
|
||||||
"""
|
"""
|
||||||
w: The original weight tensor to be modified.
|
Weight modification mode: returns modified weight.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
w: The original weight tensor to be modified.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified weight tensor.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# ===== Bypass Mode Methods =====
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component: h(x, base_out)
|
||||||
|
|
||||||
|
Computes the adapter's contribution to be added to base forward output.
|
||||||
|
For adapters that only transform output (OFT), returns zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward f(x), can be used for shape reference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delta tensor to add to base output. Shape matches base output.
|
||||||
|
|
||||||
|
Subclasses should override this method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__}.h() not implemented. "
|
||||||
|
"Subclasses must implement h() for bypass mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Output transformation: g(y)
|
||||||
|
|
||||||
|
Applied after base forward + h(x). For most adapters this is identity.
|
||||||
|
OFT overrides this to apply orthogonal transformation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y: Combined output (base + h(x))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed output
|
||||||
|
"""
|
||||||
|
# Default: identity (for LoRA/LoHa/LoKr)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def bypass_forward(
|
||||||
|
self,
|
||||||
|
org_forward: Callable,
|
||||||
|
x: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Full bypass forward: g(f(x) + h(x, f(x)))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org_forward: Original module forward function
|
||||||
|
x: Input tensor
|
||||||
|
*args, **kwargs: Additional arguments for org_forward
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output with adapter applied in bypass mode
|
||||||
|
"""
|
||||||
|
# Base forward: f(x)
|
||||||
|
base_out = org_forward(x, *args, **kwargs)
|
||||||
|
|
||||||
|
# Additive component: h(x, base_out) - base_out provided for shape reference
|
||||||
|
h_out = self.h(x, base_out)
|
||||||
|
|
||||||
|
# Output transformation: g(base + h)
|
||||||
|
return self.g(base_out + h_out)
|
||||||
|
|
||||||
def passive_memory_usage(self):
|
def passive_memory_usage(self):
|
||||||
raise NotImplementedError("passive_memory_usage is not implemented")
|
raise NotImplementedError("passive_memory_usage is not implemented")
|
||||||
|
|
||||||
@@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module):
|
|||||||
return self.passive_memory_usage()
|
return self.passive_memory_usage()
|
||||||
|
|
||||||
|
|
||||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
def weight_decompose(
|
||||||
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
|
||||||
|
):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(
|
||||||
|
dora_scale, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
lora_diff *= alpha
|
lora_diff *= alpha
|
||||||
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
|
|
||||||
@@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
|||||||
the original tensor will be truncated in that dimension.
|
the original tensor will be truncated in that dimension.
|
||||||
"""
|
"""
|
||||||
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
raise ValueError(
|
||||||
|
"The new shape must be larger than the original tensor in all dimensions"
|
||||||
|
)
|
||||||
|
|
||||||
if len(new_shape) != len(tensor.shape):
|
if len(new_shape) != len(tensor.shape):
|
||||||
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
raise ValueError(
|
||||||
|
"The new shape must have the same number of dimensions as the original tensor"
|
||||||
|
)
|
||||||
|
|
||||||
# Create a new tensor filled with zeros
|
# Create a new tensor filled with zeros
|
||||||
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|||||||
@@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase):
|
|||||||
alpha = v[2]
|
alpha = v[2]
|
||||||
dora_scale = v[3]
|
dora_scale = v[3]
|
||||||
|
|
||||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
blocks = comfy.model_management.cast_to_device(
|
||||||
|
blocks, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
if rescale is not None:
|
if rescale is not None:
|
||||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
rescale = comfy.model_management.cast_to_device(
|
||||||
|
rescale, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
|
||||||
boft_m, block_num, boft_b, *_ = blocks.shape
|
boft_m, block_num, boft_b, *_ = blocks.shape
|
||||||
|
|
||||||
@@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase):
|
|||||||
# for Q = -Q^T
|
# for Q = -Q^T
|
||||||
q = blocks - blocks.transpose(-1, -2)
|
q = blocks - blocks.transpose(-1, -2)
|
||||||
normed_q = q
|
normed_q = q
|
||||||
if alpha > 0: # alpha in boft/bboft is for constraint
|
if alpha > 0: # alpha in boft/bboft is for constraint
|
||||||
q_norm = torch.norm(q) + 1e-8
|
q_norm = torch.norm(q) + 1e-8
|
||||||
if q_norm > alpha:
|
if q_norm > alpha:
|
||||||
normed_q = q * alpha / q_norm
|
normed_q = q * alpha / q_norm
|
||||||
@@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase):
|
|||||||
r = r.to(weight)
|
r = r.to(weight)
|
||||||
inp = org = weight
|
inp = org = weight
|
||||||
|
|
||||||
r_b = boft_b//2
|
r_b = boft_b // 2
|
||||||
for i in range(boft_m):
|
for i in range(boft_m):
|
||||||
bi = r[i]
|
bi = r[i]
|
||||||
g = 2
|
g = 2
|
||||||
k = 2**i * r_b
|
k = 2**i * r_b
|
||||||
if strength != 1:
|
if strength != 1:
|
||||||
bi = bi * strength + (1-strength) * I
|
bi = bi * strength + (1 - strength) * I
|
||||||
inp = (
|
inp = (
|
||||||
inp.unflatten(0, (-1, g, k))
|
inp.unflatten(0, (-1, g, k))
|
||||||
.transpose(1, 2)
|
.transpose(1, 2)
|
||||||
@@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase):
|
|||||||
)
|
)
|
||||||
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
|
inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
|
||||||
inp = (
|
inp = (
|
||||||
inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
|
inp.flatten(0, 1)
|
||||||
|
.unflatten(0, (-1, k, g))
|
||||||
|
.transpose(1, 2)
|
||||||
|
.flatten(0, 2)
|
||||||
)
|
)
|
||||||
|
|
||||||
if rescale is not None:
|
if rescale is not None:
|
||||||
inp = inp * rescale
|
inp = inp * rescale
|
||||||
|
|
||||||
lora_diff = inp - org
|
lora_diff = inp - org
|
||||||
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
|
lora_diff = comfy.model_management.cast_to_device(
|
||||||
|
lora_diff, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weight += function((strength * lora_diff).type(weight.dtype))
|
weight += function((strength * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def _get_orthogonal_matrices(self, device, dtype):
|
||||||
|
"""Compute the orthogonal rotation matrices R from BOFT blocks."""
|
||||||
|
v = self.weights
|
||||||
|
blocks = v[0].to(device=device, dtype=dtype)
|
||||||
|
alpha = v[2]
|
||||||
|
if alpha is None:
|
||||||
|
alpha = 0
|
||||||
|
|
||||||
|
boft_m, block_num, boft_b, _ = blocks.shape
|
||||||
|
I = torch.eye(boft_b, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Q = blocks - blocks^T (skew-symmetric)
|
||||||
|
q = blocks - blocks.transpose(-1, -2)
|
||||||
|
normed_q = q
|
||||||
|
|
||||||
|
# Apply constraint if alpha > 0
|
||||||
|
if alpha > 0:
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > alpha:
|
||||||
|
normed_q = q * alpha / q_norm
|
||||||
|
|
||||||
|
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
return r, boft_m, boft_b
|
||||||
|
|
||||||
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Output transformation for BOFT: applies butterfly orthogonal transform.
|
||||||
|
|
||||||
|
BOFT uses multiple stages of butterfly-structured orthogonal transforms.
|
||||||
|
|
||||||
|
Reference: LyCORIS ButterflyOFTModule._bypass_forward
|
||||||
|
"""
|
||||||
|
v = self.weights
|
||||||
|
rescale = v[1]
|
||||||
|
|
||||||
|
r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype)
|
||||||
|
r_b = boft_b // 2
|
||||||
|
|
||||||
|
# Apply multiplier
|
||||||
|
multiplier = getattr(self, "multiplier", 1.0)
|
||||||
|
I = torch.eye(boft_b, device=y.device, dtype=y.dtype)
|
||||||
|
|
||||||
|
# Use module info from bypass injection to determine conv vs linear
|
||||||
|
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||||
|
y = y.transpose(1, -1)
|
||||||
|
|
||||||
|
# Apply butterfly transform stages
|
||||||
|
inp = y
|
||||||
|
for i in range(boft_m):
|
||||||
|
bi = r[i] # (block_num, boft_b, boft_b)
|
||||||
|
g = 2
|
||||||
|
k = 2**i * r_b
|
||||||
|
|
||||||
|
# Interpolate with identity based on multiplier
|
||||||
|
if multiplier != 1:
|
||||||
|
bi = bi * multiplier + (1 - multiplier) * I
|
||||||
|
|
||||||
|
# Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten
|
||||||
|
inp = (
|
||||||
|
inp.unflatten(-1, (-1, g, k))
|
||||||
|
.transpose(-2, -1)
|
||||||
|
.flatten(-3)
|
||||||
|
.unflatten(-1, (-1, boft_b))
|
||||||
|
)
|
||||||
|
# Apply block-diagonal orthogonal transform
|
||||||
|
inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp)
|
||||||
|
# Reshape back
|
||||||
|
inp = (
|
||||||
|
inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply rescale if present
|
||||||
|
if rescale is not None:
|
||||||
|
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||||
|
inp = inp * rescale.transpose(0, -1)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||||
|
inp = inp.transpose(1, -1)
|
||||||
|
|
||||||
|
return inp
|
||||||
|
|||||||
437
comfy/weight_adapter/bypass.py
Normal file
437
comfy/weight_adapter/bypass.py
Normal file
@@ -0,0 +1,437 @@
|
|||||||
|
"""
|
||||||
|
Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.)
|
||||||
|
|
||||||
|
Bypass mode applies adapters during forward pass without modifying base weights:
|
||||||
|
bypass(f)(x) = g(f(x) + h(x))
|
||||||
|
|
||||||
|
Where:
|
||||||
|
- f(x): Original layer forward
|
||||||
|
- h(x): Additive component from adapter (LoRA path)
|
||||||
|
- g(y): Output transformation (identity for most adapters)
|
||||||
|
|
||||||
|
This is useful for:
|
||||||
|
- Training with gradient checkpointing
|
||||||
|
- Avoiding weight modifications when weights are offloaded
|
||||||
|
- Supporting multiple adapters with different strengths dynamically
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||||
|
from comfy.patcher_extension import PatcherInjection
|
||||||
|
|
||||||
|
# Type alias for adapters that support bypass mode
|
||||||
|
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_type_info(module: nn.Module) -> dict:
|
||||||
|
"""
|
||||||
|
Determine module type and extract conv parameters from module class.
|
||||||
|
|
||||||
|
This is more reliable than checking weight.ndim, especially for quantized layers
|
||||||
|
where weight shape might be different.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys: is_conv, conv_dim, stride, padding, dilation, groups
|
||||||
|
"""
|
||||||
|
info = {
|
||||||
|
"is_conv": False,
|
||||||
|
"conv_dim": 0,
|
||||||
|
"stride": (1,),
|
||||||
|
"padding": (0,),
|
||||||
|
"dilation": (1,),
|
||||||
|
"groups": 1,
|
||||||
|
"kernel_size": (1,),
|
||||||
|
"in_channels": None,
|
||||||
|
"out_channels": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine conv type
|
||||||
|
if isinstance(module, nn.Conv1d):
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 1
|
||||||
|
elif isinstance(module, nn.Conv2d):
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 2
|
||||||
|
elif isinstance(module, nn.Conv3d):
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 3
|
||||||
|
elif isinstance(module, nn.Linear):
|
||||||
|
info["is_conv"] = False
|
||||||
|
info["conv_dim"] = 0
|
||||||
|
else:
|
||||||
|
# Try to infer from class name for custom/quantized layers
|
||||||
|
class_name = type(module).__name__.lower()
|
||||||
|
if "conv3d" in class_name:
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 3
|
||||||
|
elif "conv2d" in class_name:
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 2
|
||||||
|
elif "conv1d" in class_name:
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 1
|
||||||
|
elif "conv" in class_name:
|
||||||
|
info["is_conv"] = True
|
||||||
|
info["conv_dim"] = 2
|
||||||
|
|
||||||
|
# Extract conv parameters if it's a conv layer
|
||||||
|
if info["is_conv"]:
|
||||||
|
# Try to get stride, padding, dilation, groups, kernel_size from module
|
||||||
|
info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"])
|
||||||
|
info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"])
|
||||||
|
info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"])
|
||||||
|
info["groups"] = getattr(module, "groups", 1)
|
||||||
|
info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"])
|
||||||
|
info["in_channels"] = getattr(module, "in_channels", None)
|
||||||
|
info["out_channels"] = getattr(module, "out_channels", None)
|
||||||
|
|
||||||
|
# Ensure they're tuples
|
||||||
|
if isinstance(info["stride"], int):
|
||||||
|
info["stride"] = (info["stride"],) * info["conv_dim"]
|
||||||
|
if isinstance(info["padding"], int):
|
||||||
|
info["padding"] = (info["padding"],) * info["conv_dim"]
|
||||||
|
if isinstance(info["dilation"], int):
|
||||||
|
info["dilation"] = (info["dilation"],) * info["conv_dim"]
|
||||||
|
if isinstance(info["kernel_size"], int):
|
||||||
|
info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"]
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
class BypassForwardHook:
|
||||||
|
"""
|
||||||
|
Hook that wraps a layer's forward to apply adapter in bypass mode.
|
||||||
|
|
||||||
|
Stores the original forward and replaces it with bypass version.
|
||||||
|
|
||||||
|
Supports both:
|
||||||
|
- WeightAdapterBase: Inference adapters (uses self.weights tuple)
|
||||||
|
- WeightAdapterTrainBase: Training adapters (nn.Module with parameters)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
module: nn.Module,
|
||||||
|
adapter: BypassAdapter,
|
||||||
|
multiplier: float = 1.0,
|
||||||
|
):
|
||||||
|
self.module = module
|
||||||
|
self.adapter = adapter
|
||||||
|
self.multiplier = multiplier
|
||||||
|
self.original_forward = None
|
||||||
|
|
||||||
|
# Determine layer type and conv params from module class (works for quantized layers)
|
||||||
|
module_info = get_module_type_info(module)
|
||||||
|
|
||||||
|
# Set multiplier and layer type info on adapter for use in h()
|
||||||
|
adapter.multiplier = multiplier
|
||||||
|
adapter.is_conv = module_info["is_conv"]
|
||||||
|
adapter.conv_dim = module_info["conv_dim"]
|
||||||
|
adapter.kernel_size = module_info["kernel_size"]
|
||||||
|
adapter.in_channels = module_info["in_channels"]
|
||||||
|
adapter.out_channels = module_info["out_channels"]
|
||||||
|
# Store kw_dict for conv operations (like LyCORIS extra_args)
|
||||||
|
if module_info["is_conv"]:
|
||||||
|
adapter.kw_dict = {
|
||||||
|
"stride": module_info["stride"],
|
||||||
|
"padding": module_info["padding"],
|
||||||
|
"dilation": module_info["dilation"],
|
||||||
|
"groups": module_info["groups"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
adapter.kw_dict = {}
|
||||||
|
|
||||||
|
def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||||
|
"""Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x))
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Bypass mode does NOT access original model weights (org_weight).
|
||||||
|
This is intentional - bypass mode is designed for quantized models
|
||||||
|
where weights may not be in a usable format. All necessary shape
|
||||||
|
information is provided via adapter attributes set during inject().
|
||||||
|
"""
|
||||||
|
# Check if adapter has custom bypass_forward (e.g., GLoRA)
|
||||||
|
adapter_bypass = getattr(self.adapter, "bypass_forward", None)
|
||||||
|
if adapter_bypass is not None:
|
||||||
|
# Check if it's overridden (not the base class default)
|
||||||
|
# Need to check both base classes since adapter could be either type
|
||||||
|
adapter_type = type(self.adapter)
|
||||||
|
is_default_bypass = (
|
||||||
|
adapter_type.bypass_forward is WeightAdapterBase.bypass_forward
|
||||||
|
or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward
|
||||||
|
)
|
||||||
|
if not is_default_bypass:
|
||||||
|
return adapter_bypass(self.original_forward, x, *args, **kwargs)
|
||||||
|
|
||||||
|
# Default bypass: g(f(x) + h(x, f(x)))
|
||||||
|
base_out = self.original_forward(x, *args, **kwargs)
|
||||||
|
h_out = self.adapter.h(x, base_out)
|
||||||
|
return self.adapter.g(base_out + h_out)
|
||||||
|
|
||||||
|
def inject(self):
|
||||||
|
"""Replace module forward with bypass version."""
|
||||||
|
if self.original_forward is not None:
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassHook] Already injected for {type(self.module).__name__}"
|
||||||
|
)
|
||||||
|
return # Already injected
|
||||||
|
|
||||||
|
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
||||||
|
device = None
|
||||||
|
dtype = None
|
||||||
|
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||||
|
device = self.module.weight.device
|
||||||
|
dtype = self.module.weight.dtype
|
||||||
|
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
||||||
|
device = self.module.W_q.device
|
||||||
|
dtype = self.module.W_q.dtype
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
self._move_adapter_weights_to_device(device, dtype)
|
||||||
|
|
||||||
|
self.original_forward = self.module.forward
|
||||||
|
self.module.forward = self._bypass_forward
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _move_adapter_weights_to_device(self, device, dtype=None):
|
||||||
|
"""Move adapter weights to specified device to avoid per-forward transfers.
|
||||||
|
|
||||||
|
Handles both:
|
||||||
|
- WeightAdapterBase: has self.weights tuple of tensors
|
||||||
|
- WeightAdapterTrainBase: nn.Module with parameters, uses .to() method
|
||||||
|
"""
|
||||||
|
adapter = self.adapter
|
||||||
|
|
||||||
|
# Check if adapter is an nn.Module (WeightAdapterTrainBase)
|
||||||
|
if isinstance(adapter, nn.Module):
|
||||||
|
# In training mode we don't touch dtype as trainer will handle it
|
||||||
|
adapter.to(device=device)
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassHook] Moved training adapter (nn.Module) to {device}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# WeightAdapterBase: handle self.weights tuple
|
||||||
|
if not hasattr(adapter, "weights") or adapter.weights is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
weights = adapter.weights
|
||||||
|
if isinstance(weights, (list, tuple)):
|
||||||
|
new_weights = []
|
||||||
|
for w in weights:
|
||||||
|
if isinstance(w, torch.Tensor):
|
||||||
|
if dtype is not None:
|
||||||
|
new_weights.append(w.to(device=device, dtype=dtype))
|
||||||
|
else:
|
||||||
|
new_weights.append(w.to(device=device))
|
||||||
|
else:
|
||||||
|
new_weights.append(w)
|
||||||
|
adapter.weights = (
|
||||||
|
tuple(new_weights) if isinstance(weights, tuple) else new_weights
|
||||||
|
)
|
||||||
|
elif isinstance(weights, torch.Tensor):
|
||||||
|
if dtype is not None:
|
||||||
|
adapter.weights = weights.to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
adapter.weights = weights.to(device=device)
|
||||||
|
|
||||||
|
logging.debug(f"[BypassHook] Moved adapter weights to {device}")
|
||||||
|
|
||||||
|
def eject(self):
|
||||||
|
"""Restore original module forward."""
|
||||||
|
if self.original_forward is None:
|
||||||
|
logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
|
||||||
|
return # Not injected
|
||||||
|
|
||||||
|
self.module.forward = self.original_forward
|
||||||
|
self.original_forward = None
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BypassInjectionManager:
|
||||||
|
"""
|
||||||
|
Manages bypass mode injection for a collection of adapters.
|
||||||
|
|
||||||
|
Creates PatcherInjection objects that can be used with ModelPatcher.
|
||||||
|
|
||||||
|
Supports both inference adapters (WeightAdapterBase) and training adapters
|
||||||
|
(WeightAdapterTrainBase).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
manager = BypassInjectionManager()
|
||||||
|
manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8)
|
||||||
|
manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8)
|
||||||
|
|
||||||
|
injections = manager.create_injections(model)
|
||||||
|
model_patcher.set_injections("bypass_lora", injections)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.adapters: dict[str, tuple[BypassAdapter, float]] = {}
|
||||||
|
self.hooks: list[BypassForwardHook] = []
|
||||||
|
|
||||||
|
def add_adapter(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
adapter: BypassAdapter,
|
||||||
|
strength: float = 1.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add an adapter for a specific weight key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight")
|
||||||
|
adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.)
|
||||||
|
strength: Multiplier for adapter effect
|
||||||
|
"""
|
||||||
|
# Remove .weight suffix if present for module lookup
|
||||||
|
module_key = key
|
||||||
|
if module_key.endswith(".weight"):
|
||||||
|
module_key = module_key[:-7]
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.adapters[module_key] = (adapter, strength)
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear_adapters(self):
|
||||||
|
"""Remove all adapters."""
|
||||||
|
self.adapters.clear()
|
||||||
|
|
||||||
|
def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]:
|
||||||
|
"""Get a submodule by dot-separated key."""
|
||||||
|
parts = key.split(".")
|
||||||
|
module = model
|
||||||
|
try:
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if part.isdigit():
|
||||||
|
module = module[int(part)]
|
||||||
|
else:
|
||||||
|
module = getattr(module, part)
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] Found module for key {key}: {type(module).__name__}"
|
||||||
|
)
|
||||||
|
return module
|
||||||
|
except (AttributeError, IndexError, KeyError) as e:
|
||||||
|
logging.error(f"[BypassManager] Failed to find module for key {key}: {e}")
|
||||||
|
logging.error(
|
||||||
|
f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_injections(self, model: nn.Module) -> list[PatcherInjection]:
|
||||||
|
"""
|
||||||
|
Create PatcherInjection objects for all registered adapters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to inject into (e.g., model_patcher.model)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of PatcherInjection objects to use with model_patcher.set_injections()
|
||||||
|
"""
|
||||||
|
self.hooks.clear()
|
||||||
|
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
|
||||||
|
)
|
||||||
|
logging.debug(f"[BypassManager] Model type: {type(model).__name__}")
|
||||||
|
|
||||||
|
for key, (adapter, strength) in self.adapters.items():
|
||||||
|
logging.debug(f"[BypassManager] Looking for module: {key}")
|
||||||
|
module = self._get_module_by_key(model, key)
|
||||||
|
|
||||||
|
if module is None:
|
||||||
|
logging.warning(f"[BypassManager] Module not found for key {key}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not hasattr(module, "weight"):
|
||||||
|
logging.warning(
|
||||||
|
f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
|
||||||
|
)
|
||||||
|
hook = BypassForwardHook(module, adapter, multiplier=strength)
|
||||||
|
self.hooks.append(hook)
|
||||||
|
|
||||||
|
logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")
|
||||||
|
|
||||||
|
# Create single injection that manages all hooks
|
||||||
|
def inject_all(model_patcher):
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
|
||||||
|
)
|
||||||
|
for hook in self.hooks:
|
||||||
|
hook.inject()
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] Injected hook for {type(hook.module).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def eject_all(model_patcher):
|
||||||
|
logging.debug(
|
||||||
|
f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
|
||||||
|
)
|
||||||
|
for hook in self.hooks:
|
||||||
|
hook.eject()
|
||||||
|
|
||||||
|
return [PatcherInjection(inject=inject_all, eject=eject_all)]
|
||||||
|
|
||||||
|
def get_hook_count(self) -> int:
|
||||||
|
"""Return number of hooks that will be/are injected."""
|
||||||
|
return len(self.hooks)
|
||||||
|
|
||||||
|
|
||||||
|
def create_bypass_injections_from_patches(
|
||||||
|
model: nn.Module,
|
||||||
|
patches: dict,
|
||||||
|
strength: float = 1.0,
|
||||||
|
) -> list[PatcherInjection]:
|
||||||
|
"""
|
||||||
|
Convenience function to create bypass injections from a patches dict.
|
||||||
|
|
||||||
|
This is useful when you have patches in the format used by model_patcher.add_patches()
|
||||||
|
and want to apply them in bypass mode instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to inject into
|
||||||
|
patches: Dict mapping weight keys to adapter data
|
||||||
|
strength: Global strength multiplier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of PatcherInjection objects
|
||||||
|
"""
|
||||||
|
manager = BypassInjectionManager()
|
||||||
|
|
||||||
|
for key, patch_list in patches.items():
|
||||||
|
if not patch_list:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# patches format: list of (strength_patch, patch_data, strength_model, offset, function)
|
||||||
|
for patch in patch_list:
|
||||||
|
patch_strength, patch_data, strength_model, offset, function = patch
|
||||||
|
|
||||||
|
# patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple
|
||||||
|
if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)):
|
||||||
|
adapter = patch_data
|
||||||
|
else:
|
||||||
|
# Skip non-adapter patches
|
||||||
|
continue
|
||||||
|
|
||||||
|
combined_strength = strength * patch_strength
|
||||||
|
manager.add_adapter(key, adapter, strength=combined_strength)
|
||||||
|
|
||||||
|
return manager.create_injections(model)
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from .base import WeightAdapterBase, weight_decompose
|
from .base import WeightAdapterBase, weight_decompose
|
||||||
|
|
||||||
@@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase):
|
|||||||
b1_name = "{}.b1.weight".format(x)
|
b1_name = "{}.b1.weight".format(x)
|
||||||
b2_name = "{}.b2.weight".format(x)
|
b2_name = "{}.b2.weight".format(x)
|
||||||
if a1_name in lora:
|
if a1_name in lora:
|
||||||
weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
|
weights = (
|
||||||
|
lora[a1_name],
|
||||||
|
lora[a2_name],
|
||||||
|
lora[b1_name],
|
||||||
|
lora[b2_name],
|
||||||
|
alpha,
|
||||||
|
dora_scale,
|
||||||
|
)
|
||||||
loaded_keys.add(a1_name)
|
loaded_keys.add(a1_name)
|
||||||
loaded_keys.add(a2_name)
|
loaded_keys.add(a2_name)
|
||||||
loaded_keys.add(b1_name)
|
loaded_keys.add(b1_name)
|
||||||
@@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase):
|
|||||||
old_glora = True
|
old_glora = True
|
||||||
|
|
||||||
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
if (
|
||||||
|
old_glora
|
||||||
|
and v[1].shape[0] == weight.shape[0]
|
||||||
|
and weight.shape[0] == weight.shape[1]
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
old_glora = False
|
old_glora = False
|
||||||
rank = v[1].shape[0]
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a1 = comfy.model_management.cast_to_device(
|
||||||
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
v[0].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||||
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
)
|
||||||
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a2 = comfy.model_management.cast_to_device(
|
||||||
|
v[1].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
b1 = comfy.model_management.cast_to_device(
|
||||||
|
v[2].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
b2 = comfy.model_management.cast_to_device(
|
||||||
|
v[3].flatten(start_dim=1), weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
|
||||||
if v[4] is not None:
|
if v[4] is not None:
|
||||||
alpha = v[4] / rank
|
alpha = v[4] / rank
|
||||||
@@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if old_glora:
|
if old_glora:
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
lora_diff = (
|
||||||
|
torch.mm(b2, b1)
|
||||||
|
+ torch.mm(
|
||||||
|
torch.mm(
|
||||||
|
weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2
|
||||||
|
),
|
||||||
|
a1,
|
||||||
|
)
|
||||||
|
).reshape(
|
||||||
|
weight.shape
|
||||||
|
) # old lycoris glora
|
||||||
else:
|
else:
|
||||||
if weight.dim() > 2:
|
if weight.dim() > 2:
|
||||||
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
lora_diff = torch.einsum(
|
||||||
|
"o i ..., i j -> o j ...",
|
||||||
|
torch.einsum(
|
||||||
|
"o i ..., i j -> o j ...",
|
||||||
|
weight.to(dtype=intermediate_dtype),
|
||||||
|
a1,
|
||||||
|
),
|
||||||
|
a2,
|
||||||
|
).reshape(weight.shape)
|
||||||
else:
|
else:
|
||||||
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
lora_diff = torch.mm(
|
||||||
|
torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
|
||||||
|
).reshape(weight.shape)
|
||||||
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def _compute_paths(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Compute A path and B path outputs for GLoRA bypass.
|
||||||
|
|
||||||
|
GLoRA: f(x) = Wx + WAx + Bx
|
||||||
|
- A path: a1(a2(x)) - modifies input to base forward
|
||||||
|
- B path: b1(b2(x)) - additive component
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Does not access original model weights - bypass mode is designed
|
||||||
|
for quantized models where weights may not be accessible.
|
||||||
|
|
||||||
|
Returns: (a_out, b_out)
|
||||||
|
"""
|
||||||
|
v = self.weights
|
||||||
|
# v = (a1, a2, b1, b2, alpha, dora_scale)
|
||||||
|
a1 = v[0]
|
||||||
|
a2 = v[1]
|
||||||
|
b1 = v[2]
|
||||||
|
b2 = v[3]
|
||||||
|
alpha = v[4]
|
||||||
|
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
# Cast dtype (weights should already be on correct device from inject())
|
||||||
|
a1 = a1.to(dtype=dtype)
|
||||||
|
a2 = a2.to(dtype=dtype)
|
||||||
|
b1 = b1.to(dtype=dtype)
|
||||||
|
b2 = b2.to(dtype=dtype)
|
||||||
|
|
||||||
|
# Determine rank and scale
|
||||||
|
# Check for old vs new glora format
|
||||||
|
old_glora = False
|
||||||
|
if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
|
||||||
|
rank = a1.shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
|
||||||
|
if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = a2.shape[0]
|
||||||
|
|
||||||
|
if alpha is not None:
|
||||||
|
scale = alpha / rank
|
||||||
|
else:
|
||||||
|
scale = 1.0
|
||||||
|
|
||||||
|
# Apply multiplier
|
||||||
|
multiplier = getattr(self, "multiplier", 1.0)
|
||||||
|
scale = scale * multiplier
|
||||||
|
|
||||||
|
# Use module info from bypass injection, not input tensor shape
|
||||||
|
is_conv = getattr(self, "is_conv", False)
|
||||||
|
conv_dim = getattr(self, "conv_dim", 0)
|
||||||
|
kw_dict = getattr(self, "kw_dict", {})
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Conv case - conv_dim is 1/2/3 for conv1d/2d/3d
|
||||||
|
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||||
|
|
||||||
|
# Get module's stride/padding for spatial dimension handling
|
||||||
|
module_stride = kw_dict.get("stride", (1,) * conv_dim)
|
||||||
|
module_padding = kw_dict.get("padding", (0,) * conv_dim)
|
||||||
|
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||||
|
in_channels = getattr(self, "in_channels", None)
|
||||||
|
|
||||||
|
# Ensure weights are in conv shape
|
||||||
|
# a1, a2, b1 are always 1x1 kernels
|
||||||
|
if a1.ndim == 2:
|
||||||
|
a1 = a1.view(*a1.shape, *([1] * conv_dim))
|
||||||
|
if a2.ndim == 2:
|
||||||
|
a2 = a2.view(*a2.shape, *([1] * conv_dim))
|
||||||
|
if b1.ndim == 2:
|
||||||
|
b1 = b1.view(*b1.shape, *([1] * conv_dim))
|
||||||
|
# b2 has actual kernel_size (like LoRA down)
|
||||||
|
if b2.ndim == 2:
|
||||||
|
if in_channels is not None:
|
||||||
|
b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
|
||||||
|
else:
|
||||||
|
b2 = b2.view(*b2.shape, *([1] * conv_dim))
|
||||||
|
|
||||||
|
# A path: a2(x) -> a1(...) - 1x1 convs, no stride/padding needed, a_out is added to x
|
||||||
|
a2_out = conv_fn(x, a2)
|
||||||
|
a_out = conv_fn(a2_out, a1) * scale
|
||||||
|
|
||||||
|
# B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
|
||||||
|
b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
|
||||||
|
b_out = conv_fn(b2_out, b1) * scale
|
||||||
|
else:
|
||||||
|
# Linear case
|
||||||
|
if old_glora:
|
||||||
|
# Old format: a1 @ a2 @ x, b2 @ b1
|
||||||
|
a_out = F.linear(F.linear(x, a2), a1) * scale
|
||||||
|
b_out = F.linear(F.linear(x, b1), b2) * scale
|
||||||
|
else:
|
||||||
|
# New format: x @ a1 @ a2, b1 @ b2
|
||||||
|
a_out = F.linear(F.linear(x, a1), a2) * scale
|
||||||
|
b_out = F.linear(F.linear(x, b2), b1) * scale
|
||||||
|
|
||||||
|
return a_out, b_out
|
||||||
|
|
||||||
|
def bypass_forward(
|
||||||
|
self,
|
||||||
|
org_forward: Callable,
|
||||||
|
x: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GLoRA bypass forward: f(x + a(x)) + b(x)
|
||||||
|
|
||||||
|
Unlike standard adapters, GLoRA modifies the input to the base forward
|
||||||
|
AND adds the B path output.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Does not access original model weights - bypass mode is designed
|
||||||
|
for quantized models where weights may not be accessible.
|
||||||
|
|
||||||
|
Reference: LyCORIS GLoRAModule._bypass_forward
|
||||||
|
"""
|
||||||
|
a_out, b_out = self._compute_paths(x)
|
||||||
|
|
||||||
|
# Call base forward with modified input
|
||||||
|
base_out = org_forward(x + a_out, *args, **kwargs)
|
||||||
|
|
||||||
|
# Add B path
|
||||||
|
return base_out + b_out
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
For GLoRA, h() returns the B path output.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
GLoRA's full bypass requires overriding bypass_forward() since
|
||||||
|
it also modifies the input to org_forward. This h() is provided for
|
||||||
|
compatibility but bypass_forward() should be used for correct behavior.
|
||||||
|
|
||||||
|
Does not access original model weights - bypass mode is designed
|
||||||
|
for quantized models where weights may not be accessible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward (unused, for API consistency)
|
||||||
|
"""
|
||||||
|
_, b_out = self._compute_paths(x)
|
||||||
|
return b_out
|
||||||
|
|||||||
@@ -1,11 +1,22 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from functools import cache
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _warn_loha_bypass_inefficient():
|
||||||
|
"""One-time warning about LoHa bypass inefficiency."""
|
||||||
|
logging.warning(
|
||||||
|
"LoHa bypass mode is inefficient: full weight diff is computed each forward pass. "
|
||||||
|
"Consider using LoRA or LoKr for training with bypass mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HadaWeight(torch.autograd.Function):
|
class HadaWeight(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||||
@@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase):
|
|||||||
|
|
||||||
scale = self.alpha / self.rank
|
scale = self.alpha / self.rank
|
||||||
if self.use_tucker:
|
if self.use_tucker:
|
||||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
diff_weight = HadaWeightTucker.apply(
|
||||||
|
self.hada_t1,
|
||||||
|
self.hada_w1_a,
|
||||||
|
self.hada_w1_b,
|
||||||
|
self.hada_t2,
|
||||||
|
self.hada_w2_a,
|
||||||
|
self.hada_w2_b,
|
||||||
|
scale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
diff_weight = HadaWeight.apply(
|
||||||
|
self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale
|
||||||
|
)
|
||||||
|
|
||||||
# Add the scaled difference to the original weight
|
# Add the scaled difference to the original weight
|
||||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||||
@@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase):
|
|||||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||||
torch.nn.init.normal_(mat3, 0.1)
|
torch.nn.init.normal_(mat3, 0.1)
|
||||||
torch.nn.init.normal_(mat4, 0.01)
|
torch.nn.init.normal_(mat4, 0.01)
|
||||||
return LohaDiff(
|
return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None))
|
||||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_train(self):
|
def to_train(self):
|
||||||
return LohaDiff(self.weights)
|
return LohaDiff(self.weights)
|
||||||
@@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase):
|
|||||||
loaded_keys.add(hada_t1_name)
|
loaded_keys.add(hada_t1_name)
|
||||||
loaded_keys.add(hada_t2_name)
|
loaded_keys.add(hada_t2_name)
|
||||||
|
|
||||||
weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
|
weights = (
|
||||||
|
lora[hada_w1_a_name],
|
||||||
|
lora[hada_w1_b_name],
|
||||||
|
alpha,
|
||||||
|
lora[hada_w2_a_name],
|
||||||
|
lora[hada_w2_b_name],
|
||||||
|
hada_t1,
|
||||||
|
hada_t2,
|
||||||
|
dora_scale,
|
||||||
|
)
|
||||||
loaded_keys.add(hada_w1_a_name)
|
loaded_keys.add(hada_w1_a_name)
|
||||||
loaded_keys.add(hada_w1_b_name)
|
loaded_keys.add(hada_w1_b_name)
|
||||||
loaded_keys.add(hada_w2_a_name)
|
loaded_keys.add(hada_w2_a_name)
|
||||||
@@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase):
|
|||||||
w2a = v[3]
|
w2a = v[3]
|
||||||
w2b = v[4]
|
w2b = v[4]
|
||||||
dora_scale = v[7]
|
dora_scale = v[7]
|
||||||
if v[5] is not None: #cp decomposition
|
if v[5] is not None: # cp decomposition
|
||||||
t1 = v[5]
|
t1 = v[5]
|
||||||
t2 = v[6]
|
t2 = v[6]
|
||||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
m1 = torch.einsum(
|
||||||
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
"i j k l, j r, i p -> p r k l",
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
comfy.model_management.cast_to_device(
|
||||||
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
t1, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w1b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w1a, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
m2 = torch.einsum(
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
"i j k l, j r, i p -> p r k l",
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
comfy.model_management.cast_to_device(
|
||||||
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
t2, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2a, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
m1 = torch.mm(
|
||||||
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
comfy.model_management.cast_to_device(
|
||||||
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
w1a, weight.device, intermediate_dtype
|
||||||
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w1b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
m2 = torch.mm(
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2a, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component for LoHa: h(x) = diff_weight @ x
|
||||||
|
|
||||||
|
WARNING: Inefficient - computes full Hadamard product each forward.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Does not access original model weights - bypass mode is designed
|
||||||
|
for quantized models where weights may not be accessible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward (unused, for API consistency)
|
||||||
|
|
||||||
|
Reference: LyCORIS functional/loha.py bypass_forward_diff
|
||||||
|
"""
|
||||||
|
_warn_loha_bypass_inefficient()
|
||||||
|
|
||||||
|
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
|
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
|
|
||||||
|
v = self.weights
|
||||||
|
# v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora
|
||||||
|
w1a = v[0]
|
||||||
|
w1b = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
w2a = v[3]
|
||||||
|
w2b = v[4]
|
||||||
|
t1 = v[5]
|
||||||
|
t2 = v[6]
|
||||||
|
|
||||||
|
# Compute scale
|
||||||
|
rank = w1b.shape[0]
|
||||||
|
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||||
|
self, "multiplier", 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cast dtype
|
||||||
|
w1a = w1a.to(dtype=x.dtype)
|
||||||
|
w1b = w1b.to(dtype=x.dtype)
|
||||||
|
w2a = w2a.to(dtype=x.dtype)
|
||||||
|
w2b = w2b.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
# Use module info from bypass injection, not weight dimension
|
||||||
|
is_conv = getattr(self, "is_conv", False)
|
||||||
|
conv_dim = getattr(self, "conv_dim", 0)
|
||||||
|
kw_dict = getattr(self, "kw_dict", {})
|
||||||
|
|
||||||
|
# Compute diff weight using Hadamard product
|
||||||
|
if t1 is not None and t2 is not None:
|
||||||
|
t1 = t1.to(dtype=x.dtype)
|
||||||
|
t2 = t2.to(dtype=x.dtype)
|
||||||
|
m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a)
|
||||||
|
m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a)
|
||||||
|
diff_weight = (m1 * m2) * scale
|
||||||
|
else:
|
||||||
|
m1 = w1a @ w1b
|
||||||
|
m2 = w2a @ w2b
|
||||||
|
diff_weight = (m1 * m2) * scale
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
op = FUNC_LIST[conv_dim + 2]
|
||||||
|
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||||
|
in_channels = getattr(self, "in_channels", None)
|
||||||
|
|
||||||
|
# Reshape 2D diff_weight to conv format using kernel_size
|
||||||
|
# diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size]
|
||||||
|
if diff_weight.dim() == 2:
|
||||||
|
if in_channels is not None:
|
||||||
|
diff_weight = diff_weight.view(
|
||||||
|
diff_weight.shape[0], in_channels, *kernel_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
diff_weight = diff_weight.view(
|
||||||
|
*diff_weight.shape, *([1] * conv_dim)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
op = F.linear
|
||||||
|
kw_dict = {}
|
||||||
|
|
||||||
|
return op(x, diff_weight, **kw_dict)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from .base import (
|
from .base import (
|
||||||
WeightAdapterBase,
|
WeightAdapterBase,
|
||||||
@@ -14,7 +15,17 @@ from .base import (
|
|||||||
class LokrDiff(WeightAdapterTrainBase):
|
class LokrDiff(WeightAdapterTrainBase):
|
||||||
def __init__(self, weights):
|
def __init__(self, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
(
|
||||||
|
lokr_w1,
|
||||||
|
lokr_w2,
|
||||||
|
alpha,
|
||||||
|
lokr_w1_a,
|
||||||
|
lokr_w1_b,
|
||||||
|
lokr_w2_a,
|
||||||
|
lokr_w2_b,
|
||||||
|
lokr_t2,
|
||||||
|
dora_scale,
|
||||||
|
) = weights
|
||||||
self.use_tucker = False
|
self.use_tucker = False
|
||||||
if lokr_w1_a is not None:
|
if lokr_w1_a is not None:
|
||||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||||
@@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase):
|
|||||||
if self.w2_rebuild:
|
if self.w2_rebuild:
|
||||||
if self.use_tucker:
|
if self.use_tucker:
|
||||||
w2 = torch.einsum(
|
w2 = torch.einsum(
|
||||||
'i j k l, j r, i p -> p r k l',
|
"i j k l, j r, i p -> p r k l",
|
||||||
self.lokr_t2,
|
self.lokr_t2,
|
||||||
self.lokr_w2_b,
|
self.lokr_w2_b,
|
||||||
self.lokr_w2_a
|
self.lokr_w2_a,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||||
@@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase):
|
|||||||
return self.lokr_w2
|
return self.lokr_w2
|
||||||
|
|
||||||
def __call__(self, w):
|
def __call__(self, w):
|
||||||
diff = torch.kron(self.w1, self.w2)
|
w1 = self.w1
|
||||||
|
w2 = self.w2
|
||||||
|
# Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
|
||||||
|
for _ in range(w2.dim() - w1.dim()):
|
||||||
|
w1 = w1.unsqueeze(-1)
|
||||||
|
diff = torch.kron(w1, w2)
|
||||||
return w + diff.reshape(w.shape).to(w)
|
return w + diff.reshape(w.shape).to(w)
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component for LoKr training: efficient Kronecker product.
|
||||||
|
|
||||||
|
Uses w1/w2 properties which handle both direct and decomposed cases.
|
||||||
|
For create_train (direct w1/w2), no alpha scaling in properties.
|
||||||
|
For to_train (decomposed), alpha/rank scaling is in properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward (unused, for API consistency)
|
||||||
|
"""
|
||||||
|
# Get w1, w2 from properties (handles rebuild vs direct)
|
||||||
|
w1 = self.w1
|
||||||
|
w2 = self.w2
|
||||||
|
|
||||||
|
# Multiplier from bypass injection
|
||||||
|
multiplier = getattr(self, "multiplier", 1.0)
|
||||||
|
|
||||||
|
# Get module info from bypass injection
|
||||||
|
is_conv = getattr(self, "is_conv", False)
|
||||||
|
conv_dim = getattr(self, "conv_dim", 0)
|
||||||
|
kw_dict = getattr(self, "kw_dict", {})
|
||||||
|
|
||||||
|
# Efficient Kronecker application without materializing full weight
|
||||||
|
# kron(w1, w2) @ x can be computed as nested operations
|
||||||
|
# w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
|
||||||
|
# Full weight would be [out_l*out_k, in_m*in_n, *k_size]
|
||||||
|
|
||||||
|
uq = w1.size(1) # in_m - inner grouping dimension
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||||
|
|
||||||
|
B, C_in, *spatial = x.shape
|
||||||
|
# Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
|
||||||
|
h_in_group = x.reshape(B * uq, -1, *spatial)
|
||||||
|
|
||||||
|
# Ensure w2 has conv dims
|
||||||
|
if w2.dim() == 2:
|
||||||
|
w2 = w2.view(*w2.shape, *([1] * conv_dim))
|
||||||
|
|
||||||
|
# Apply w2 path with stride/padding
|
||||||
|
hb = conv_fn(h_in_group, w2, **kw_dict)
|
||||||
|
|
||||||
|
# Reshape for cross-group operation
|
||||||
|
hb = hb.view(B, -1, *hb.shape[1:])
|
||||||
|
h_cross = hb.transpose(1, -1)
|
||||||
|
|
||||||
|
# Apply w1 (always 2D, applied as linear on channel dim)
|
||||||
|
hc = F.linear(h_cross, w1)
|
||||||
|
hc = hc.transpose(1, -1)
|
||||||
|
|
||||||
|
# Reshape to output
|
||||||
|
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||||
|
else:
|
||||||
|
# Linear case
|
||||||
|
# Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
|
||||||
|
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||||
|
|
||||||
|
# Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
|
||||||
|
hb = F.linear(h_in_group, w2)
|
||||||
|
|
||||||
|
# Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
|
||||||
|
h_cross = hb.transpose(-1, -2)
|
||||||
|
|
||||||
|
# Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
|
||||||
|
hc = F.linear(h_cross, w1)
|
||||||
|
|
||||||
|
# Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
|
||||||
|
hc = hc.transpose(-1, -2)
|
||||||
|
out = hc.reshape(*hc.shape[:-2], -1)
|
||||||
|
|
||||||
|
return out * multiplier
|
||||||
|
|
||||||
def passive_memory_usage(self):
|
def passive_memory_usage(self):
|
||||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||||
|
|
||||||
@@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||||
out_dim = weight.shape[0]
|
out_dim = weight.shape[0]
|
||||||
in_dim = weight.shape[1:].numel()
|
in_dim = weight.shape[1] # Just in_channels, not flattened with kernel
|
||||||
out1, out2 = factorization(out_dim, rank)
|
k_size = weight.shape[2:] if weight.dim() > 2 else ()
|
||||||
in1, in2 = factorization(in_dim, rank)
|
|
||||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
|
out_l, out_k = factorization(out_dim, rank)
|
||||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
|
in_m, in_n = factorization(in_dim, rank)
|
||||||
|
|
||||||
|
# w1: [out_l, in_m]
|
||||||
|
mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
|
||||||
|
# w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
|
||||||
|
mat2 = torch.empty(
|
||||||
|
out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||||
torch.nn.init.constant_(mat1, 0.0)
|
torch.nn.init.constant_(mat1, 0.0)
|
||||||
return LokrDiff(
|
return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
|
||||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_train(self):
|
def to_train(self):
|
||||||
return LokrDiff(self.weights)
|
return LokrDiff(self.weights)
|
||||||
@@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
lokr_t2 = lora[lokr_t2_name]
|
lokr_t2 = lora[lokr_t2_name]
|
||||||
loaded_keys.add(lokr_t2_name)
|
loaded_keys.add(lokr_t2_name)
|
||||||
|
|
||||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
if (
|
||||||
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
(lokr_w1 is not None)
|
||||||
|
or (lokr_w2 is not None)
|
||||||
|
or (lokr_w1_a is not None)
|
||||||
|
or (lokr_w2_a is not None)
|
||||||
|
):
|
||||||
|
weights = (
|
||||||
|
lokr_w1,
|
||||||
|
lokr_w2,
|
||||||
|
alpha,
|
||||||
|
lokr_w1_a,
|
||||||
|
lokr_w1_b,
|
||||||
|
lokr_w2_a,
|
||||||
|
lokr_w2_b,
|
||||||
|
lokr_t2,
|
||||||
|
dora_scale,
|
||||||
|
)
|
||||||
return cls(loaded_keys, weights)
|
return cls(loaded_keys, weights)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
|
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
dim = w1_b.shape[0]
|
dim = w1_b.shape[0]
|
||||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
w1 = torch.mm(
|
||||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
comfy.model_management.cast_to_device(
|
||||||
|
w1_a, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w1_b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
w1 = comfy.model_management.cast_to_device(
|
||||||
|
w1, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
|
||||||
if w2 is None:
|
if w2 is None:
|
||||||
dim = w2_b.shape[0]
|
dim = w2_b.shape[0]
|
||||||
if t2 is None:
|
if t2 is None:
|
||||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
w2 = torch.mm(
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
comfy.model_management.cast_to_device(
|
||||||
|
w2_a, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2_b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
w2 = torch.einsum(
|
||||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
"i j k l, j r, i p -> p r k l",
|
||||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
comfy.model_management.cast_to_device(
|
||||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
t2, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2_b, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
comfy.model_management.cast_to_device(
|
||||||
|
w2_a, weight.device, intermediate_dtype
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
w2 = comfy.model_management.cast_to_device(
|
||||||
|
w2, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
@@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase):
|
|||||||
try:
|
try:
|
||||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component for LoKr: efficient Kronecker product application.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Does not access original model weights - bypass mode is designed
|
||||||
|
for quantized models where weights may not be accessible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward (unused, for API consistency)
|
||||||
|
|
||||||
|
Reference: LyCORIS functional/lokr.py bypass_forward_diff
|
||||||
|
"""
|
||||||
|
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
|
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
|
|
||||||
|
v = self.weights
|
||||||
|
# v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
|
||||||
|
w1 = v[0]
|
||||||
|
w2 = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
w1_a = v[3]
|
||||||
|
w1_b = v[4]
|
||||||
|
w2_a = v[5]
|
||||||
|
w2_b = v[6]
|
||||||
|
t2 = v[7]
|
||||||
|
|
||||||
|
use_w1 = w1 is not None
|
||||||
|
use_w2 = w2 is not None
|
||||||
|
tucker = t2 is not None
|
||||||
|
|
||||||
|
# Use module info from bypass injection, not weight dimension
|
||||||
|
is_conv = getattr(self, "is_conv", False)
|
||||||
|
conv_dim = getattr(self, "conv_dim", 0)
|
||||||
|
kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
op = FUNC_LIST[conv_dim + 2]
|
||||||
|
else:
|
||||||
|
op = F.linear
|
||||||
|
|
||||||
|
# Determine rank and scale
|
||||||
|
rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha
|
||||||
|
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||||
|
self, "multiplier", 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build c (w1)
|
||||||
|
if use_w1:
|
||||||
|
c = w1.to(dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
|
||||||
|
uq = c.size(1)
|
||||||
|
|
||||||
|
# Build w2 components
|
||||||
|
if use_w2:
|
||||||
|
ba = w2.to(dtype=x.dtype)
|
||||||
|
else:
|
||||||
|
a = w2_b.to(dtype=x.dtype)
|
||||||
|
b = w2_a.to(dtype=x.dtype)
|
||||||
|
if is_conv:
|
||||||
|
if tucker:
|
||||||
|
# Tucker: a, b get 1s appended (kernel is in t2)
|
||||||
|
if a.dim() == 2:
|
||||||
|
a = a.view(*a.shape, *([1] * conv_dim))
|
||||||
|
if b.dim() == 2:
|
||||||
|
b = b.view(*b.shape, *([1] * conv_dim))
|
||||||
|
else:
|
||||||
|
# Non-tucker conv: b may need 1s appended
|
||||||
|
if b.dim() == 2:
|
||||||
|
b = b.view(*b.shape, *([1] * conv_dim))
|
||||||
|
|
||||||
|
# Reshape input by uq groups
|
||||||
|
if is_conv:
|
||||||
|
B, _, *rest = x.shape
|
||||||
|
h_in_group = x.reshape(B * uq, -1, *rest)
|
||||||
|
else:
|
||||||
|
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||||
|
|
||||||
|
# Apply w2 path
|
||||||
|
if use_w2:
|
||||||
|
hb = op(h_in_group, ba, **kw_dict)
|
||||||
|
else:
|
||||||
|
if is_conv:
|
||||||
|
if tucker:
|
||||||
|
t = t2.to(dtype=x.dtype)
|
||||||
|
if t.dim() == 2:
|
||||||
|
t = t.view(*t.shape, *([1] * conv_dim))
|
||||||
|
ha = op(h_in_group, a)
|
||||||
|
ht = op(ha, t, **kw_dict)
|
||||||
|
hb = op(ht, b)
|
||||||
|
else:
|
||||||
|
ha = op(h_in_group, a, **kw_dict)
|
||||||
|
hb = op(ha, b)
|
||||||
|
else:
|
||||||
|
ha = op(h_in_group, a)
|
||||||
|
hb = op(ha, b)
|
||||||
|
|
||||||
|
# Reshape and apply c (w1)
|
||||||
|
if is_conv:
|
||||||
|
hb = hb.view(B, -1, *hb.shape[1:])
|
||||||
|
h_cross_group = hb.transpose(1, -1)
|
||||||
|
else:
|
||||||
|
h_cross_group = hb.transpose(-1, -2)
|
||||||
|
|
||||||
|
hc = F.linear(h_cross_group, c)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
hc = hc.transpose(1, -1)
|
||||||
|
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||||
|
else:
|
||||||
|
hc = hc.transpose(-1, -2)
|
||||||
|
out = hc.reshape(*hc.shape[:-2], -1)
|
||||||
|
|
||||||
|
return out * scale
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from .base import (
|
from .base import (
|
||||||
WeightAdapterBase,
|
WeightAdapterBase,
|
||||||
@@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase):
|
|||||||
rank, in_dim = mat2.shape[0], mat2.shape[1]
|
rank, in_dim = mat2.shape[0], mat2.shape[1]
|
||||||
if mid is not None:
|
if mid is not None:
|
||||||
convdim = mid.ndim - 2
|
convdim = mid.ndim - 2
|
||||||
layer = (
|
layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
|
||||||
torch.nn.Conv1d,
|
|
||||||
torch.nn.Conv2d,
|
|
||||||
torch.nn.Conv3d
|
|
||||||
)[convdim]
|
|
||||||
else:
|
else:
|
||||||
layer = torch.nn.Linear
|
layer = torch.nn.Linear
|
||||||
self.lora_up = layer(rank, out_dim, bias=False)
|
self.lora_up = layer(rank, out_dim, bias=False)
|
||||||
@@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase):
|
|||||||
weight = w + scale * diff.reshape(w.shape)
|
weight = w + scale * diff.reshape(w.shape)
|
||||||
return weight.to(org_dtype)
|
return weight.to(org_dtype)
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component for LoRA training: h(x) = up(down(x)) * scale
|
||||||
|
|
||||||
|
Simple implementation using the nn.Module weights directly.
|
||||||
|
No mid/dora/reshape branches (create_train doesn't create them).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward (unused, for API consistency)
|
||||||
|
"""
|
||||||
|
# Compute scale = alpha / rank * multiplier
|
||||||
|
scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)
|
||||||
|
|
||||||
|
# Get module info from bypass injection
|
||||||
|
is_conv = getattr(self, "is_conv", False)
|
||||||
|
conv_dim = getattr(self, "conv_dim", 0)
|
||||||
|
kw_dict = getattr(self, "kw_dict", {})
|
||||||
|
|
||||||
|
# Get weights (keep in original dtype for numerical stability)
|
||||||
|
down_weight = self.lora_down.weight
|
||||||
|
up_weight = self.lora_up.weight
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Conv path: use functional conv
|
||||||
|
# conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
|
||||||
|
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||||
|
|
||||||
|
# Reshape 2D weights to conv format if needed
|
||||||
|
# down: [rank, in_features] -> [rank, in_channels, *kernel_size]
|
||||||
|
# up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
|
||||||
|
if down_weight.dim() == 2:
|
||||||
|
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||||
|
in_channels = getattr(self, "in_channels", None)
|
||||||
|
if in_channels is not None:
|
||||||
|
down_weight = down_weight.view(
|
||||||
|
down_weight.shape[0], in_channels, *kernel_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback: assume 1x1 kernel
|
||||||
|
down_weight = down_weight.view(
|
||||||
|
*down_weight.shape, *([1] * conv_dim)
|
||||||
|
)
|
||||||
|
if up_weight.dim() == 2:
|
||||||
|
# up always uses 1x1 kernel
|
||||||
|
up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))
|
||||||
|
|
||||||
|
# down conv uses stride/padding from module, up is 1x1
|
||||||
|
hidden = conv_fn(x, down_weight, **kw_dict)
|
||||||
|
|
||||||
|
# mid layer if exists (tucker decomposition)
|
||||||
|
if self.lora_mid is not None:
|
||||||
|
mid_weight = self.lora_mid.weight
|
||||||
|
if mid_weight.dim() == 2:
|
||||||
|
mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
|
||||||
|
hidden = conv_fn(hidden, mid_weight)
|
||||||
|
|
||||||
|
# up conv is always 1x1 (no stride/padding)
|
||||||
|
out = conv_fn(hidden, up_weight)
|
||||||
|
else:
|
||||||
|
# Linear path: simple matmul chain
|
||||||
|
hidden = F.linear(x, down_weight)
|
||||||
|
|
||||||
|
# mid layer if exists
|
||||||
|
if self.lora_mid is not None:
|
||||||
|
mid_weight = self.lora_mid.weight
|
||||||
|
hidden = F.linear(hidden, mid_weight)
|
||||||
|
|
||||||
|
out = F.linear(hidden, up_weight)
|
||||||
|
|
||||||
|
return out * scale
|
||||||
|
|
||||||
def passive_memory_usage(self):
|
def passive_memory_usage(self):
|
||||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||||
|
|
||||||
@@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase):
|
|||||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
|
||||||
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
|
torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
|
||||||
torch.nn.init.constant_(mat2, 0.0)
|
torch.nn.init.constant_(mat2, 0.0)
|
||||||
return LoraDiff(
|
return LoraDiff((mat1, mat2, alpha, None, None, None))
|
||||||
(mat1, mat2, alpha, None, None, None)
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_train(self):
|
def to_train(self):
|
||||||
return LoraDiff(self.weights)
|
return LoraDiff(self.weights)
|
||||||
@@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Additive bypass component for LoRA: h(x) = up(down(x)) * scale
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Does not access original model weights - bypass mode is designed
|
||||||
|
for quantized models where weights may not be accessible.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
base_out: Output from base forward (unused, for API consistency)
|
||||||
|
|
||||||
|
Reference: LyCORIS functional/locon.py bypass_forward_diff
|
||||||
|
"""
|
||||||
|
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
|
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||||
|
|
||||||
|
v = self.weights
|
||||||
|
# v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
|
||||||
|
up = v[0]
|
||||||
|
down = v[1]
|
||||||
|
alpha = v[2]
|
||||||
|
mid = v[3]
|
||||||
|
|
||||||
|
# Compute scale = alpha / rank
|
||||||
|
rank = down.shape[0]
|
||||||
|
if alpha is not None:
|
||||||
|
scale = alpha / rank
|
||||||
|
else:
|
||||||
|
scale = 1.0
|
||||||
|
scale = scale * getattr(self, "multiplier", 1.0)
|
||||||
|
|
||||||
|
# Cast dtype
|
||||||
|
up = up.to(dtype=x.dtype)
|
||||||
|
down = down.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
# Use module info from bypass injection, not weight dimension
|
||||||
|
is_conv = getattr(self, "is_conv", False)
|
||||||
|
conv_dim = getattr(self, "conv_dim", 0)
|
||||||
|
kw_dict = getattr(self, "kw_dict", {})
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
op = FUNC_LIST[
|
||||||
|
conv_dim + 2
|
||||||
|
] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
|
||||||
|
kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
|
||||||
|
in_channels = getattr(self, "in_channels", None)
|
||||||
|
|
||||||
|
# Reshape 2D weights to conv format using kernel_size
|
||||||
|
# down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
|
||||||
|
# up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
|
||||||
|
if down.dim() == 2:
|
||||||
|
# down.shape[1] = in_channels * prod(kernel_size)
|
||||||
|
if in_channels is not None:
|
||||||
|
down = down.view(down.shape[0], in_channels, *kernel_size)
|
||||||
|
else:
|
||||||
|
# Fallback: assume 1x1 kernel if in_channels unknown
|
||||||
|
down = down.view(*down.shape, *([1] * conv_dim))
|
||||||
|
if up.dim() == 2:
|
||||||
|
# up always uses 1x1 kernel
|
||||||
|
up = up.view(*up.shape, *([1] * conv_dim))
|
||||||
|
if mid is not None:
|
||||||
|
mid = mid.to(dtype=x.dtype)
|
||||||
|
if mid.dim() == 2:
|
||||||
|
mid = mid.view(*mid.shape, *([1] * conv_dim))
|
||||||
|
else:
|
||||||
|
op = F.linear
|
||||||
|
kw_dict = {} # linear doesn't take stride/padding
|
||||||
|
|
||||||
|
# Simple chain: down -> mid (if tucker) -> up
|
||||||
|
if mid is not None:
|
||||||
|
if not is_conv:
|
||||||
|
mid = mid.to(dtype=x.dtype)
|
||||||
|
hidden = op(x, down)
|
||||||
|
hidden = op(hidden, mid, **kw_dict)
|
||||||
|
out = op(hidden, up)
|
||||||
|
else:
|
||||||
|
hidden = op(x, down, **kw_dict)
|
||||||
|
out = op(hidden, up)
|
||||||
|
|
||||||
|
return out * scale
|
||||||
|
|||||||
@@ -3,13 +3,18 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
from .base import (
|
||||||
|
WeightAdapterBase,
|
||||||
|
WeightAdapterTrainBase,
|
||||||
|
weight_decompose,
|
||||||
|
factorization,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OFTDiff(WeightAdapterTrainBase):
|
class OFTDiff(WeightAdapterTrainBase):
|
||||||
def __init__(self, weights):
|
def __init__(self, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Unpack weights tuple from LoHaAdapter
|
# Unpack weights tuple from OFTAdapter
|
||||||
blocks, rescale, alpha, _ = weights
|
blocks, rescale, alpha, _ = weights
|
||||||
|
|
||||||
# Create trainable parameters
|
# Create trainable parameters
|
||||||
@@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase):
|
|||||||
weight = self.rescale * weight
|
weight = self.rescale * weight
|
||||||
return weight.to(org_dtype)
|
return weight.to(org_dtype)
|
||||||
|
|
||||||
|
def _get_orthogonal_matrix(self, device, dtype):
|
||||||
|
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||||
|
blocks = self.oft_blocks.to(device=device, dtype=dtype)
|
||||||
|
I = torch.eye(self.block_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Q = blocks - blocks^T (skew-symmetric)
|
||||||
|
q = blocks - blocks.transpose(1, 2)
|
||||||
|
normed_q = q
|
||||||
|
|
||||||
|
# Apply constraint if set
|
||||||
|
if self.constraint:
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > self.constraint:
|
||||||
|
normed_q = q * self.constraint / q_norm
|
||||||
|
|
||||||
|
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
return r.to(dtype)
|
||||||
|
|
||||||
|
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
OFT has no additive component - returns zeros matching base_out shape.
|
||||||
|
|
||||||
|
OFT only transforms the output via g(), it doesn't add to it.
|
||||||
|
"""
|
||||||
|
return torch.zeros_like(base_out)
|
||||||
|
|
||||||
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Output transformation for OFT: applies orthogonal rotation.
|
||||||
|
|
||||||
|
OFT transforms output channels using block-diagonal orthogonal matrices.
|
||||||
|
"""
|
||||||
|
r = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||||
|
|
||||||
|
# Apply multiplier to interpolate between identity and full transform
|
||||||
|
multiplier = getattr(self, "multiplier", 1.0)
|
||||||
|
I = torch.eye(self.block_size, device=y.device, dtype=y.dtype)
|
||||||
|
r = r * multiplier + (1 - multiplier) * I
|
||||||
|
|
||||||
|
# Use module info from bypass injection
|
||||||
|
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||||
|
y = y.transpose(1, -1)
|
||||||
|
|
||||||
|
# y now has channels in last dim
|
||||||
|
*batch_shape, out_features = y.shape
|
||||||
|
|
||||||
|
# Reshape to apply block-diagonal transform
|
||||||
|
# (*, out_features) -> (*, block_num, block_size)
|
||||||
|
y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size)
|
||||||
|
|
||||||
|
# Apply orthogonal transform: R @ y for each block
|
||||||
|
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||||
|
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||||
|
|
||||||
|
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||||
|
out = out_blocked.reshape(*batch_shape, out_features)
|
||||||
|
|
||||||
|
# Apply rescale if present
|
||||||
|
if self.rescaled:
|
||||||
|
rescale = self.rescale.to(device=y.device, dtype=y.dtype)
|
||||||
|
out = out * rescale.view(-1)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||||
|
out = out.transpose(1, -1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
def passive_memory_usage(self):
|
def passive_memory_usage(self):
|
||||||
"""Calculates memory usage of the trainable parameters."""
|
"""Calculates memory usage of the trainable parameters."""
|
||||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||||
@@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase):
|
|||||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||||
out_dim = weight.shape[0]
|
out_dim = weight.shape[0]
|
||||||
block_size, block_num = factorization(out_dim, rank)
|
block_size, block_num = factorization(out_dim, rank)
|
||||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32)
|
block = torch.zeros(
|
||||||
return OFTDiff(
|
block_num, block_size, block_size, device=weight.device, dtype=torch.float32
|
||||||
(block, None, alpha, None)
|
|
||||||
)
|
)
|
||||||
|
return OFTDiff((block, None, alpha, None))
|
||||||
|
|
||||||
def to_train(self):
|
def to_train(self):
|
||||||
return OFTDiff(self.weights)
|
return OFTDiff(self.weights)
|
||||||
@@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase):
|
|||||||
alpha = 0
|
alpha = 0
|
||||||
dora_scale = v[3]
|
dora_scale = v[3]
|
||||||
|
|
||||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
blocks = comfy.model_management.cast_to_device(
|
||||||
|
blocks, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
if rescale is not None:
|
if rescale is not None:
|
||||||
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
|
rescale = comfy.model_management.cast_to_device(
|
||||||
|
rescale, weight.device, intermediate_dtype
|
||||||
|
)
|
||||||
|
|
||||||
block_num, block_size, *_ = blocks.shape
|
block_num, block_size, *_ = blocks.shape
|
||||||
|
|
||||||
@@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase):
|
|||||||
# for Q = -Q^T
|
# for Q = -Q^T
|
||||||
q = blocks - blocks.transpose(1, 2)
|
q = blocks - blocks.transpose(1, 2)
|
||||||
normed_q = q
|
normed_q = q
|
||||||
if alpha > 0: # alpha in oft/boft is for constraint
|
if alpha > 0: # alpha in oft/boft is for constraint
|
||||||
q_norm = torch.norm(q) + 1e-8
|
q_norm = torch.norm(q) + 1e-8
|
||||||
if q_norm > alpha:
|
if q_norm > alpha:
|
||||||
normed_q = q * alpha / q_norm
|
normed_q = q * alpha / q_norm
|
||||||
# use float() to prevent unsupported type in .inverse()
|
# use float() to prevent unsupported type in .inverse()
|
||||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
r = r.to(weight)
|
r = r.to(weight)
|
||||||
|
# Create I in weight's dtype for the einsum
|
||||||
|
I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
|
||||||
_, *shape = weight.shape
|
_, *shape = weight.shape
|
||||||
lora_diff = torch.einsum(
|
lora_diff = torch.einsum(
|
||||||
"k n m, k n ... -> k m ...",
|
"k n m, k n ... -> k m ...",
|
||||||
(r * strength) - strength * I,
|
(r * strength) - strength * I_w,
|
||||||
weight.view(block_num, block_size, *shape),
|
weight.view(block_num, block_size, *shape),
|
||||||
).view(-1, *shape)
|
).view(-1, *shape)
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
weight = weight_decompose(
|
||||||
|
dora_scale,
|
||||||
|
weight,
|
||||||
|
lora_diff,
|
||||||
|
alpha,
|
||||||
|
strength,
|
||||||
|
intermediate_dtype,
|
||||||
|
function,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weight += function((strength * lora_diff).type(weight.dtype))
|
weight += function((strength * lora_diff).type(weight.dtype))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def _get_orthogonal_matrix(self, device, dtype):
|
||||||
|
"""Compute the orthogonal rotation matrix R from OFT blocks."""
|
||||||
|
v = self.weights
|
||||||
|
blocks = v[0].to(device=device, dtype=dtype)
|
||||||
|
alpha = v[2]
|
||||||
|
if alpha is None:
|
||||||
|
alpha = 0
|
||||||
|
|
||||||
|
block_num, block_size, _ = blocks.shape
|
||||||
|
I = torch.eye(block_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Q = blocks - blocks^T (skew-symmetric)
|
||||||
|
q = blocks - blocks.transpose(1, 2)
|
||||||
|
normed_q = q
|
||||||
|
|
||||||
|
# Apply constraint if alpha > 0
|
||||||
|
if alpha > 0:
|
||||||
|
q_norm = torch.norm(q) + 1e-8
|
||||||
|
if q_norm > alpha:
|
||||||
|
normed_q = q * alpha / q_norm
|
||||||
|
|
||||||
|
# Cayley transform: R = (I + Q)(I - Q)^-1
|
||||||
|
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||||
|
return r, block_num, block_size
|
||||||
|
|
||||||
|
def g(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Output transformation for OFT: applies orthogonal rotation to output.
|
||||||
|
|
||||||
|
OFT transforms the output channels using block-diagonal orthogonal matrices.
|
||||||
|
|
||||||
|
Reference: LyCORIS DiagOFTModule._bypass_forward
|
||||||
|
"""
|
||||||
|
v = self.weights
|
||||||
|
rescale = v[1]
|
||||||
|
|
||||||
|
r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
|
||||||
|
|
||||||
|
# Apply multiplier to interpolate between identity and full transform
|
||||||
|
multiplier = getattr(self, "multiplier", 1.0)
|
||||||
|
I = torch.eye(block_size, device=y.device, dtype=y.dtype)
|
||||||
|
r = r * multiplier + (1 - multiplier) * I
|
||||||
|
|
||||||
|
# Use module info from bypass injection to determine conv vs linear
|
||||||
|
is_conv = getattr(self, "is_conv", y.dim() > 2)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
|
||||||
|
y = y.transpose(1, -1)
|
||||||
|
|
||||||
|
# y now has channels in last dim
|
||||||
|
*batch_shape, out_features = y.shape
|
||||||
|
|
||||||
|
# Reshape to apply block-diagonal transform
|
||||||
|
# (*, out_features) -> (*, block_num, block_size)
|
||||||
|
y_blocked = y.view(*batch_shape, block_num, block_size)
|
||||||
|
|
||||||
|
# Apply orthogonal transform: R @ y for each block
|
||||||
|
# r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
|
||||||
|
out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
|
||||||
|
|
||||||
|
# Reshape back: (*, block_num, block_size) -> (*, out_features)
|
||||||
|
out = out_blocked.view(*batch_shape, out_features)
|
||||||
|
|
||||||
|
# Apply rescale if present
|
||||||
|
if rescale is not None:
|
||||||
|
rescale = rescale.to(device=y.device, dtype=y.dtype)
|
||||||
|
out = out * rescale.view(-1)
|
||||||
|
|
||||||
|
if is_conv:
|
||||||
|
# Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
|
||||||
|
out = out.transpose(1, -1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfy.weight_adapter import adapters, adapter_maps
|
from comfy.weight_adapter import adapters, adapter_maps
|
||||||
|
from comfy.weight_adapter.bypass import BypassInjectionManager
|
||||||
from comfy_api.latest import ComfyExtension, io, ui
|
from comfy_api.latest import ComfyExtension, io, ui
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
|
|
||||||
@@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||||
|
|
||||||
if (i + 1) % self.grad_acc == 0:
|
if (i + 1) % self.grad_acc == 0:
|
||||||
|
for param_groups in self.optimizer.param_groups:
|
||||||
|
for param in param_groups["params"]:
|
||||||
|
if param.grad is None:
|
||||||
|
continue
|
||||||
|
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
ui_pbar.update(1)
|
ui_pbar.update(1)
|
||||||
@@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
|||||||
num_images = sum(t.shape[0] for t in latents)
|
num_images = sum(t.shape[0] for t in latents)
|
||||||
multi_res = False # Not using multi_res path in bucket mode
|
multi_res = False # Not using multi_res path in bucket mode
|
||||||
|
|
||||||
logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples")
|
||||||
for i, lat in enumerate(latents):
|
for i, lat in enumerate(latents):
|
||||||
logging.info(f" Bucket {i}: shape {lat.shape}")
|
logging.debug(f" Bucket {i}: shape {lat.shape}")
|
||||||
return latents, num_images, multi_res
|
return latents, num_images, multi_res
|
||||||
|
|
||||||
# Non-bucket mode
|
# Non-bucket mode
|
||||||
@@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode):
|
|||||||
latents = [t.to(dtype) for t in latents]
|
latents = [t.to(dtype) for t in latents]
|
||||||
for latent in latents:
|
for latent in latents:
|
||||||
all_shapes.add(latent.shape)
|
all_shapes.add(latent.shape)
|
||||||
logging.info(f"Latent shapes: {all_shapes}")
|
logging.debug(f"Latent shapes: {all_shapes}")
|
||||||
if len(all_shapes) > 1:
|
if len(all_shapes) > 1:
|
||||||
multi_res = True
|
multi_res = True
|
||||||
else:
|
else:
|
||||||
@@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode):
|
|||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
return positive # Skip validation in bucket mode
|
return positive # Skip validation in bucket mode
|
||||||
|
|
||||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||||
if len(positive) == 1 and num_images > 1:
|
if len(positive) == 1 and num_images > 1:
|
||||||
return positive * num_images
|
return positive * num_images
|
||||||
elif len(positive) != num_images:
|
elif len(positive) != num_images:
|
||||||
@@ -596,6 +602,8 @@ def _create_weight_adapter(
|
|||||||
shape = module.weight.shape
|
shape = module.weight.shape
|
||||||
lora_params = {}
|
lora_params = {}
|
||||||
|
|
||||||
|
logging.debug(f"Creating weight adapter for {key} with shape {shape}")
|
||||||
|
|
||||||
if len(shape) >= 2:
|
if len(shape) >= 2:
|
||||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||||
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
dora_scale = existing_weights.get(f"{key}.dora_scale", None)
|
||||||
@@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank):
|
|||||||
return lora_sd, all_weight_adapters
|
return lora_sd, all_weight_adapters
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank):
|
||||||
|
"""Setup LoRA adapters in bypass mode.
|
||||||
|
|
||||||
|
In bypass mode:
|
||||||
|
- Weight adapters (lora/lokr/oft) use bypass injection (forward hook)
|
||||||
|
- Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification)
|
||||||
|
|
||||||
|
This is useful when the base model weights are quantized and cannot be
|
||||||
|
directly modified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mp: Model patcher
|
||||||
|
existing_weights: Dict of existing LoRA weights
|
||||||
|
algorithm: Algorithm name for new adapters
|
||||||
|
lora_dtype: dtype for LoRA weights
|
||||||
|
rank: Rank for new LoRA adapters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (lora_sd dict, all_weight_adapters list, bypass_manager)
|
||||||
|
"""
|
||||||
|
lora_sd = {}
|
||||||
|
all_weight_adapters = []
|
||||||
|
bypass_manager = BypassInjectionManager()
|
||||||
|
|
||||||
|
for n, m in mp.model.named_modules():
|
||||||
|
if hasattr(m, "weight_function"):
|
||||||
|
if m.weight is not None:
|
||||||
|
adapter, params = _create_weight_adapter(
|
||||||
|
m, n, existing_weights, algorithm, lora_dtype, rank
|
||||||
|
)
|
||||||
|
lora_sd.update(params)
|
||||||
|
all_weight_adapters.append(adapter)
|
||||||
|
|
||||||
|
key = f"{n}.weight"
|
||||||
|
# BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass
|
||||||
|
# Only use bypass for adapters that have h() method (lora/lokr/oft)
|
||||||
|
if isinstance(adapter, BiasDiff):
|
||||||
|
mp.add_weight_wrapper(key, adapter)
|
||||||
|
logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}")
|
||||||
|
else:
|
||||||
|
bypass_manager.add_adapter(key, adapter, strength=1.0)
|
||||||
|
logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}")
|
||||||
|
|
||||||
|
if hasattr(m, "bias") and m.bias is not None:
|
||||||
|
# Bias adapters still use weight wrapper (bias is usually not quantized)
|
||||||
|
bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype)
|
||||||
|
lora_sd.update(bias_params)
|
||||||
|
key = f"{n}.bias"
|
||||||
|
mp.add_weight_wrapper(key, bias_adapter)
|
||||||
|
all_weight_adapters.append(bias_adapter)
|
||||||
|
logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}")
|
||||||
|
|
||||||
|
return lora_sd, all_weight_adapters, bypass_manager
|
||||||
|
|
||||||
|
|
||||||
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
def _create_optimizer(optimizer_name, parameters, learning_rate):
|
||||||
"""Create optimizer based on name.
|
"""Create optimizer based on name.
|
||||||
|
|
||||||
@@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default=False,
|
default=False,
|
||||||
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"bypass_mode",
|
||||||
|
default=False,
|
||||||
|
tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. Useful for quantized models where weights cannot be directly modified.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(
|
|
||||||
display_name="model", tooltip="Model with LoRA applied"
|
|
||||||
),
|
|
||||||
io.Custom("LORA_MODEL").Output(
|
io.Custom("LORA_MODEL").Output(
|
||||||
display_name="lora", tooltip="LoRA weights"
|
display_name="lora", tooltip="LoRA weights"
|
||||||
),
|
),
|
||||||
@@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
existing_lora,
|
existing_lora,
|
||||||
bucket_mode,
|
bucket_mode,
|
||||||
|
bypass_mode,
|
||||||
):
|
):
|
||||||
# Extract scalars from lists (due to is_input_list=True)
|
# Extract scalars from lists (due to is_input_list=True)
|
||||||
model = model[0]
|
model = model[0]
|
||||||
@@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
|
bypass_mode = bypass_mode[0]
|
||||||
|
|
||||||
# Process latents based on mode
|
# Process latents based on mode
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
@@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||||
|
|
||||||
# Setup LoRA adapters
|
# Setup LoRA adapters
|
||||||
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
bypass_manager = None
|
||||||
mp, existing_weights, algorithm, lora_dtype, rank
|
if bypass_mode:
|
||||||
)
|
logging.debug("Using bypass mode for training")
|
||||||
|
lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass(
|
||||||
|
mp, existing_weights, algorithm, lora_dtype, rank
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lora_sd, all_weight_adapters = _setup_lora_adapters(
|
||||||
|
mp, existing_weights, algorithm, lora_dtype, rank
|
||||||
|
)
|
||||||
|
|
||||||
# Create optimizer and loss function
|
# Create optimizer and loss function
|
||||||
optimizer = _create_optimizer(
|
optimizer = _create_optimizer(
|
||||||
@@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
guider = TrainGuider(mp)
|
guider = TrainGuider(mp)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(positive)
|
||||||
|
|
||||||
|
# Inject bypass hooks if bypass mode is enabled
|
||||||
|
bypass_injections = None
|
||||||
|
if bypass_manager is not None:
|
||||||
|
bypass_injections = bypass_manager.create_injections(mp.model)
|
||||||
|
for injection in bypass_injections:
|
||||||
|
injection.inject(mp)
|
||||||
|
logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks")
|
||||||
|
|
||||||
# Run training loop
|
# Run training loop
|
||||||
try:
|
try:
|
||||||
_run_training_loop(
|
_run_training_loop(
|
||||||
@@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
multi_res,
|
multi_res,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
# Eject bypass hooks if they were injected
|
||||||
|
if bypass_injections is not None:
|
||||||
|
for injection in bypass_injections:
|
||||||
|
injection.eject(mp)
|
||||||
|
logging.debug("[BypassMode] Ejected bypass hooks")
|
||||||
for m in mp.model.modules():
|
for m in mp.model.modules():
|
||||||
unpatch(m)
|
unpatch(m)
|
||||||
del train_sampler, optimizer
|
del train_sampler, optimizer
|
||||||
@@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
for param in lora_sd:
|
for param in lora_sd:
|
||||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||||
|
|
||||||
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
# mp in train node is highly specialized for training
|
||||||
|
# use it in inference will result in bad behavior so we don't return it
|
||||||
|
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||||
|
|
||||||
|
|
||||||
class LoraModelLoader(io.ComfyNode):#
|
class LoraModelLoader(io.ComfyNode):#
|
||||||
|
|||||||
67
nodes.py
67
nodes.py
@@ -722,6 +722,69 @@ class LoraLoaderModelOnly(LoraLoader):
|
|||||||
def load_lora_model_only(self, model, lora_name, strength_model):
|
def load_lora_model_only(self, model, lora_name, strength_model):
|
||||||
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
||||||
|
|
||||||
|
class LoraLoaderBypass:
|
||||||
|
"""
|
||||||
|
Apply LoRA in bypass mode without modifying base model weights.
|
||||||
|
|
||||||
|
Bypass mode computes: output = base_forward(x) + lora_path(x)
|
||||||
|
This is useful for training and when model weights are offloaded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.loaded_lora = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
||||||
|
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
|
||||||
|
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
|
||||||
|
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL", "CLIP")
|
||||||
|
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
|
||||||
|
FUNCTION = "load_lora"
|
||||||
|
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
|
||||||
|
|
||||||
|
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||||
|
if strength_model == 0 and strength_clip == 0:
|
||||||
|
return (model, clip)
|
||||||
|
|
||||||
|
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||||
|
lora = None
|
||||||
|
if self.loaded_lora is not None:
|
||||||
|
if self.loaded_lora[0] == lora_path:
|
||||||
|
lora = self.loaded_lora[1]
|
||||||
|
else:
|
||||||
|
self.loaded_lora = None
|
||||||
|
|
||||||
|
if lora is None:
|
||||||
|
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||||
|
self.loaded_lora = (lora_path, lora)
|
||||||
|
|
||||||
|
model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||||
|
return (model_lora, clip_lora)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "load_lora_model_only"
|
||||||
|
|
||||||
|
def load_lora_model_only(self, model, lora_name, strength_model):
|
||||||
|
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
||||||
@@ -2067,6 +2130,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LatentFlip": LatentFlip,
|
"LatentFlip": LatentFlip,
|
||||||
"LatentCrop": LatentCrop,
|
"LatentCrop": LatentCrop,
|
||||||
"LoraLoader": LoraLoader,
|
"LoraLoader": LoraLoader,
|
||||||
|
"LoraLoaderBypass": LoraLoaderBypass,
|
||||||
|
"LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
|
||||||
"CLIPLoader": CLIPLoader,
|
"CLIPLoader": CLIPLoader,
|
||||||
"UNETLoader": UNETLoader,
|
"UNETLoader": UNETLoader,
|
||||||
"DualCLIPLoader": DualCLIPLoader,
|
"DualCLIPLoader": DualCLIPLoader,
|
||||||
@@ -2106,6 +2171,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||||
"VAELoader": "Load VAE",
|
"VAELoader": "Load VAE",
|
||||||
"LoraLoader": "Load LoRA",
|
"LoraLoader": "Load LoRA",
|
||||||
|
"LoraLoaderBypass": "Load LoRA (Bypass)",
|
||||||
|
"LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only)",
|
||||||
"CLIPLoader": "Load CLIP",
|
"CLIPLoader": "Load CLIP",
|
||||||
"ControlNetLoader": "Load ControlNet Model",
|
"ControlNetLoader": "Load ControlNet Model",
|
||||||
"DiffControlNetLoader": "Load ControlNet Model (diff)",
|
"DiffControlNetLoader": "Load ControlNet Model (diff)",
|
||||||
|
|||||||
Reference in New Issue
Block a user