mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
WIP - adding support for flux DoRA and ip adapter training
This commit is contained in:
@@ -6,6 +6,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import TYPE_CHECKING, Union, List
|
||||
|
||||
from optimum.quanto import QBytesTensor, QTensor
|
||||
|
||||
from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -89,6 +91,7 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
|
||||
# m = Magnitude column-wise across output dimension
|
||||
weight = self.get_orig_weight()
|
||||
weight = weight.to(self.lora_up.weight.device, dtype=self.lora_up.weight.dtype)
|
||||
lora_weight = self.lora_up.weight @ self.lora_down.weight
|
||||
weight_norm = self._get_weight_norm(weight, lora_weight)
|
||||
self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True)
|
||||
@@ -99,7 +102,11 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
# del self.org_module
|
||||
|
||||
def get_orig_weight(self):
|
||||
return self.org_module[0].weight.data.detach()
|
||||
weight = self.org_module[0].weight
|
||||
if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
|
||||
return weight.dequantize().data.detach()
|
||||
else:
|
||||
return weight.data.detach()
|
||||
|
||||
def get_orig_bias(self):
|
||||
if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
|
||||
@@ -126,6 +133,7 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
|
||||
# magnitude = self.lora_magnitude_vector[active_adapter]
|
||||
weight = self.get_orig_weight()
|
||||
weight = weight.to(scaled_lora_weight.device, dtype=scaled_lora_weight.dtype)
|
||||
weight_norm = self._get_weight_norm(weight, scaled_lora_weight)
|
||||
# see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
|
||||
# "[...] we suggest treating ||V +∆V ||_c in
|
||||
@@ -135,4 +143,4 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
# during backpropagation"
|
||||
weight_norm = weight_norm.detach()
|
||||
dora_weight = transpose(weight + scaled_lora_weight, False)
|
||||
return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)
|
||||
return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x.to(dora_weight.dtype), dora_weight)
|
||||
|
||||
Reference in New Issue
Block a user