Fixed Dora implementation. Still highly experimental

This commit is contained in:
Jaret Burkett
2024-02-24 10:26:01 -07:00
parent 1bd94f0f01
commit f965a1299f
5 changed files with 128 additions and 33 deletions

View File

@@ -22,6 +22,13 @@ CONV_MODULES = [
'LoRACompatibleConv'
]
def transpose(weight, fan_in_fan_out):
if not fan_in_fan_out:
return weight
if isinstance(weight, torch.nn.Parameter):
return torch.nn.Parameter(weight.T)
return weight.T
class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
# def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
@@ -65,15 +72,26 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
self.module_dropout = module_dropout
self.is_checkpointing = False
# m = Magnitude column-wise across output dimension
self.magnitude = nn.Parameter(self.get_orig_weight().norm(p=2, dim=0, keepdim=True))
d_out = org_module.out_features
d_in = org_module.in_features
std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float())
self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev)
self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in))
# self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) # lora_A
# self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) # lora_B
self.lora_up = nn.Linear(self.lora_dim, d_out, bias=False) # lora_B
# self.lora_up.weight.data = torch.randn_like(self.lora_up.weight.data) * std_dev
self.lora_up.weight.data = torch.zeros_like(self.lora_up.weight.data)
# self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
# self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
self.lora_down = nn.Linear(d_in, self.lora_dim, bias=False) # lora_A
# self.lora_down.weight.data = torch.zeros_like(self.lora_down.weight.data)
self.lora_down.weight.data = torch.randn_like(self.lora_down.weight.data) * std_dev
# m = Magnitude column-wise across output dimension
weight = self.get_orig_weight()
lora_weight = self.lora_up.weight @ self.lora_down.weight
weight_norm = self._get_weight_norm(weight, lora_weight)
self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True)
def apply_to(self):
self.org_forward = self.org_module[0].forward
@@ -88,11 +106,33 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
return self.org_module[0].bias.data.detach()
return None
def dora_forward(self, x, *args, **kwargs):
lora = torch.matmul(self.lora_up, self.lora_down)
adapted = self.get_orig_weight() + lora
column_norm = adapted.norm(p=2, dim=0, keepdim=True)
norm_adapted = adapted / column_norm
calc_weights = self.magnitude * norm_adapted
return F.linear(x, calc_weights, self.get_orig_bias())
# def dora_forward(self, x, *args, **kwargs):
# lora = torch.matmul(self.lora_A, self.lora_B)
# adapted = self.get_orig_weight() + lora
# column_norm = adapted.norm(p=2, dim=0, keepdim=True)
# norm_adapted = adapted / column_norm
# calc_weights = self.magnitude * norm_adapted
# return F.linear(x, calc_weights, self.get_orig_bias())
def _get_weight_norm(self, weight, scaled_lora_weight) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
weight = weight + scaled_lora_weight.to(weight.device)
weight_norm = torch.linalg.norm(weight, dim=1)
return weight_norm
def apply_dora(self, x, scaled_lora_weight):
# ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L192
# lora weight is already scaled
# magnitude = self.lora_magnitude_vector[active_adapter]
weight = self.get_orig_weight()
weight_norm = self._get_weight_norm(weight, scaled_lora_weight)
# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
# "[...] we suggest treating ||V +∆V ||_c in
# Eq. (5) as a constant, thereby detaching it from the gradient
# graph. This means that while ||V + ∆V ||_c dynamically
# reflects the updates of ∆V , it wont receive any gradient
# during backpropagation"
weight_norm = weight_norm.detach()
dora_weight = transpose(weight + scaled_lora_weight, False)
return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)