Various bug fixes and optimizations for quantized training. Added untested custom adam8bit optimizer. Did some work on LoRM (don't use).

This commit is contained in:
Jaret Burkett
2024-11-20 09:16:55 -07:00
parent 6509ba4484
commit 894374b2e9
7 changed files with 241 additions and 18 deletions
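The commit message flags the new adam8bit optimizer as untested. For orientation, the sketch below shows the general technique such an optimizer relies on: keeping Adam's two moment buffers in int8 with a per-tensor scale and dequantizing around each update. It is a generic illustration of 8-bit optimizer state, not the implementation this commit adds, and every name in it is hypothetical.

import torch

def _quantize_state(t):
    # hypothetical helper: store a float tensor as int8 plus a per-tensor scale
    scale = t.abs().max().clamp(min=1e-8) / 127.0
    return (t / scale).round().clamp(-127, 127).to(torch.int8), scale

def _dequantize_state(q, scale):
    return q.float() * scale

@torch.no_grad()
def adam8bit_step(param, grad, state, lr=1e-4, betas=(0.9, 0.999), eps=1e-8):
    # dequantize the running moments, apply a standard Adam update,
    # then re-quantize the moments back to int8
    m = _dequantize_state(*state['m'])
    v = _dequantize_state(*state['v'])
    m.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    v.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    param.add_(m / (v.sqrt() + eps), alpha=-lr)  # bias correction omitted for brevity
    state['m'] = _quantize_state(m)
    state['v'] = _quantize_state(v)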


@@ -15,6 +15,7 @@ from toolkit.lorm import extract_conv, extract_linear, count_parameters
 from toolkit.metadata import add_model_hash_to_meta
 from toolkit.paths import KEYMAPS_ROOT
 from toolkit.saving import get_lora_keymap_from_model_keymap
+from optimum.quanto import QBytesTensor
 
 if TYPE_CHECKING:
     from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
@@ -27,7 +28,8 @@ Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']
 LINEAR_MODULES = [
     'Linear',
-    'LoRACompatibleLinear'
+    'LoRACompatibleLinear',
+    'QLinear'
     # 'GroupNorm',
 ]
 CONV_MODULES = [
@@ -108,11 +110,16 @@ class ExtractableModuleMixin:
         if extract_mode == "existing":
             extract_mode = 'fixed'
             extract_mode_param = self.lora_dim
+
+        if isinstance(weight_to_extract, QBytesTensor):
+            weight_to_extract = weight_to_extract.dequantize()
+
+        weight_to_extract = weight_to_extract.clone().detach().float()
 
         if self.org_module[0].__class__.__name__ in CONV_MODULES:
             # do conv extraction
             down_weight, up_weight, new_dim, diff = extract_conv(
-                weight=weight_to_extract.clone().detach().float(),
+                weight=weight_to_extract,
                 mode=extract_mode,
                 mode_param=extract_mode_param,
                 device=device
@@ -121,7 +128,7 @@ class ExtractableModuleMixin:
         elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
             # do linear extraction
             down_weight, up_weight, new_dim, diff = extract_linear(
-                weight=weight_to_extract.clone().detach().float(),
+                weight=weight_to_extract,
                 mode=extract_mode,
                 mode_param=extract_mode_param,
                 device=device,
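The two hunks above funnel both conv and linear extraction through a single dequantize-then-float step instead of repeating .clone().detach().float() at each call site. A minimal sketch of the pattern, assuming optimum.quanto's quantize/freeze flow materializes 8-bit linear weights as QBytesTensor and that extract_linear is SVD-based and therefore needs a plain float tensor:

import torch
from optimum.quanto import QBytesTensor, freeze, qint8, quantize

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize(model, weights=qint8)  # swap nn.Linear for quanto's QLinear
freeze(model)                   # materialize the int8 weights

weight = model[0].weight
if isinstance(weight, QBytesTensor):
    weight = weight.dequantize()          # recover a plain tensor
weight = weight.clone().detach().float()  # now safe for SVD-style extraction
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)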
@@ -210,6 +217,11 @@ class ToolkitModuleMixin:
         network: Network = self.network_ref()
         if not network.is_active:
             return self.org_forward(x, *args, **kwargs)
+
+        orig_dtype = x.dtype
+        if x.dtype != self.lora_down.weight.dtype:
+            x = x.to(self.lora_down.weight.dtype)
+
         if network.lorm_train_mode == 'local':
             # we are going to predict input with both and do a loss on them
@@ -230,7 +242,10 @@ class ToolkitModuleMixin:
             return target_pred
         else:
-            return self.lora_up(self.lora_down(x))
+            x = self.lora_up(self.lora_down(x))
+            if x.dtype != orig_dtype:
+                x = x.to(orig_dtype)
+            return x
 
     def forward(self: Module, x, *args, **kwargs):
         skip = False
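The last two hunks wrap the LoRA path in a dtype round-trip: cast the incoming activations to the adapter weights' dtype for the matmuls, then cast the result back so downstream layers see the dtype they expect (quantized base models often run bf16 activations against fp32 adapters). A standalone sketch of the same pattern; the module names here are illustrative, not the repo's classes:

import torch

lora_down = torch.nn.Linear(64, 8, bias=False)  # fp32 adapter weights
lora_up = torch.nn.Linear(8, 64, bias=False)

x = torch.randn(2, 64, dtype=torch.bfloat16)    # bf16 activations from the base model
orig_dtype = x.dtype
if x.dtype != lora_down.weight.dtype:
    x = x.to(lora_down.weight.dtype)            # avoid a mixed-dtype matmul error
x = lora_up(lora_down(x))
if x.dtype != orig_dtype:
    x = x.to(orig_dtype)                        # hand the caller back its original dtype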