import logging
from typing import Callable, Optional

import torch
import torch.nn.functional as F

import comfy.model_management
from .base import WeightAdapterBase, weight_decompose


class GLoRAAdapter(WeightAdapterBase):
    name = "glora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["GLoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            weights = (
                lora[a1_name],
                lora[a2_name],
                lora[b1_name],
                lora[b2_name],
                alpha,
                dora_scale,
            )
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        dora_scale = v[5]

        old_glora = False
        if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
            rank = v[0].shape[0]
            old_glora = True

        if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
            if (
                old_glora
                and v[1].shape[0] == weight.shape[0]
                and weight.shape[0] == weight.shape[1]
            ):
                pass
            else:
                old_glora = False
                rank = v[1].shape[0]

        a1 = comfy.model_management.cast_to_device(
            v[0].flatten(start_dim=1), weight.device, intermediate_dtype
        )
        a2 = comfy.model_management.cast_to_device(
            v[1].flatten(start_dim=1), weight.device, intermediate_dtype
        )
        b1 = comfy.model_management.cast_to_device(
            v[2].flatten(start_dim=1), weight.device, intermediate_dtype
        )
        b2 = comfy.model_management.cast_to_device(
            v[3].flatten(start_dim=1), weight.device, intermediate_dtype
        )

        if v[4] is not None:
            alpha = v[4] / rank
        else:
            alpha = 1.0

        try:
            if old_glora:
                # old lycoris glora
                lora_diff = (
                    torch.mm(b2, b1)
                    + torch.mm(
                        torch.mm(
                            weight.flatten(start_dim=1).to(dtype=intermediate_dtype),
                            a2,
                        ),
                        a1,
                    )
                ).reshape(weight.shape)
            else:
                if weight.dim() > 2:
                    lora_diff = torch.einsum(
                        "o i ..., i j -> o j ...",
                        torch.einsum(
                            "o i ..., i j -> o j ...",
                            weight.to(dtype=intermediate_dtype),
                            a1,
                        ),
                        a2,
                    ).reshape(weight.shape)
                else:
                    lora_diff = torch.mm(
                        torch.mm(weight.to(dtype=intermediate_dtype), a1), a2
                    ).reshape(weight.shape)
                lora_diff += torch.mm(b1, b2).reshape(weight.shape)

            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
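    # Shape conventions for the flattened 2D weight views, as implied by the
    # matmuls above (the old format additionally requires a square weight,
    # i.e. out == in):
    #   old: a1 (rank, in), a2 (in, rank), b1 (rank, in), b2 (out, rank)
    #        lora_diff = b2 @ b1 + W @ a2 @ a1
    #   new: a1 (in, rank), a2 (rank, in), b1 (out, rank), b2 (rank, in)
    #        lora_diff = b1 @ b2 + W @ a1 @ a2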
    def _compute_paths(self, x: torch.Tensor):
        """
        Compute A path and B path outputs for GLoRA bypass.

        GLoRA: f(x) = Wx + WAx + Bx
        - A path: a1(a2(x)) - modifies the input to the base forward
        - B path: b1(b2(x)) - additive component

        Note: Does not access original model weights - bypass mode is designed
        for quantized models where weights may not be accessible.

        Returns: (a_out, b_out)
        """
        v = self.weights
        # v = (a1, a2, b1, b2, alpha, dora_scale)
        a1 = v[0]
        a2 = v[1]
        b1 = v[2]
        b2 = v[3]
        alpha = v[4]

        dtype = x.dtype

        # Cast dtype (weights should already be on the correct device from inject())
        a1 = a1.to(dtype=dtype)
        a2 = a2.to(dtype=dtype)
        b1 = b1.to(dtype=dtype)
        b2 = b2.to(dtype=dtype)

        # Determine rank and scale
        # Check for old vs new glora format
        old_glora = False
        if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]:
            rank = a1.shape[0]
            old_glora = True

        if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]:
            # Ambiguous case. calculate_weight() disambiguates with a square
            # weight check (out == in); only the input dim is visible here,
            # so match a2's output dim against it instead.
            if old_glora and a2.shape[0] == x.shape[-1]:
                pass
            else:
                old_glora = False
                rank = a2.shape[0]

        if alpha is not None:
            scale = alpha / rank
        else:
            scale = 1.0

        # Apply multiplier
        multiplier = getattr(self, "multiplier", 1.0)
        scale = scale * multiplier

        # Use module info from bypass injection, not input tensor shape
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        if is_conv:
            # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d.
            # Note: this branch assumes the new-format weight layout.
            conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]

            # Get the module's stride/padding for spatial dimension handling
            module_stride = kw_dict.get("stride", (1,) * conv_dim)
            module_padding = kw_dict.get("padding", (0,) * conv_dim)
            kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
            in_channels = getattr(self, "in_channels", None)

            # Ensure weights are in conv shape.
            # a1, a2, b1 are always 1x1 kernels.
            if a1.ndim == 2:
                a1 = a1.view(*a1.shape, *([1] * conv_dim))
            if a2.ndim == 2:
                a2 = a2.view(*a2.shape, *([1] * conv_dim))
            if b1.ndim == 2:
                b1 = b1.view(*b1.shape, *([1] * conv_dim))
            # b2 has the actual kernel_size (like LoRA down)
            if b2.ndim == 2:
                if in_channels is not None:
                    b2 = b2.view(b2.shape[0], in_channels, *kernel_size)
                else:
                    b2 = b2.view(*b2.shape, *([1] * conv_dim))

            # A path: a2(x) -> a1(...). 1x1 convs, no stride/padding needed;
            # a_out is added to x.
            a2_out = conv_fn(x, a2)
            a_out = conv_fn(a2_out, a1) * scale

            # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1
            b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding)
            b_out = conv_fn(b2_out, b1) * scale
        else:
            # Linear case
            if old_glora:
                # Old format: lora_diff = b2 @ b1 + W @ a2 @ a1, so the
                # input-side transform is a2(a1(x)) and the additive path
                # is b2(b1(x)).
                a_out = F.linear(F.linear(x, a1), a2) * scale
                b_out = F.linear(F.linear(x, b1), b2) * scale
            else:
                # New format: lora_diff = b1 @ b2 + W @ a1 @ a2, so the
                # input-side transform is a1(a2(x)) and the additive path
                # is b1(b2(x)).
                a_out = F.linear(F.linear(x, a2), a1) * scale
                b_out = F.linear(F.linear(x, b2), b1) * scale

        return a_out, b_out

    def bypass_forward(
        self,
        org_forward: Callable,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        GLoRA bypass forward: f(x + a(x)) + b(x)

        Unlike standard adapters, GLoRA modifies the input to the base forward
        AND adds the B path output.

        Note: Does not access original model weights - bypass mode is designed
        for quantized models where weights may not be accessible.

        Reference: LyCORIS GLoRAModule._bypass_forward
        """
        a_out, b_out = self._compute_paths(x)

        # Call the base forward with the modified input
        base_out = org_forward(x + a_out, *args, **kwargs)

        # Add the B path
        return base_out + b_out
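    # For a linear base layer, bypass_forward reproduces the merged-weight
    # result of calculate_weight without touching W itself:
    #   org_forward(x + a(x)) + b(x) = W x + W A x + B x = (W + W A + B) x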
    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        For GLoRA, h() returns the B path output.

        Note: GLoRA's full bypass requires overriding bypass_forward() since
        it also modifies the input to org_forward. This h() is provided for
        compatibility, but bypass_forward() should be used for correct
        behavior.

        Does not access original model weights - bypass mode is designed for
        quantized models where weights may not be accessible.

        Args:
            x: Input tensor
            base_out: Output from the base forward (unused, for API consistency)
        """
        _, b_out = self._compute_paths(x)
        return b_out
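
# A minimal consistency-check sketch (not part of the adapter API). It builds
# synthetic new-format GLoRA weights and verifies that bypass_forward() around
# an untouched nn.Linear matches calculate_weight() applied to a merged copy
# of the weight. The sizes, seed, and tolerance below are illustrative only.
# Run it as a module (e.g. `python -m <package>.glora`), since this file uses
# relative imports.
if __name__ == "__main__":
    torch.manual_seed(0)
    out_f, in_f, rank = 8, 6, 2
    base = torch.nn.Linear(in_f, out_f, bias=False)

    # New-format shapes: a1 (in, r), a2 (r, in), b1 (out, r), b2 (r, in)
    glora_weights = (
        torch.randn(in_f, rank) * 0.1,   # a1
        torch.randn(rank, in_f) * 0.1,   # a2
        torch.randn(out_f, rank) * 0.1,  # b1
        torch.randn(rank, in_f) * 0.1,   # b2
        None,  # alpha
        None,  # dora_scale
    )
    adapter = GLoRAAdapter(set(), glora_weights)

    x = torch.randn(3, in_f)

    # Path 1: merge the GLoRA diff into a copy of the weight, then run it.
    merged = adapter.calculate_weight(
        base.weight.detach().clone(), "demo", 1.0, 1.0, None, lambda w: w
    )
    expected = F.linear(x, merged)

    # Path 2: bypass forward around the unmodified base layer.
    actual = adapter.bypass_forward(base, x)

    assert torch.allclose(expected, actual, atol=1e-5)
    print("bypass/merge agree, max abs diff:", (expected - actual).abs().max().item())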