mirror of https://github.com/ostris/ai-toolkit.git
Various bug fixes and optimizations for quantized training. Added untested custom adam8bit optimizer. Did some work on LoRM (dont use)
@@ -15,6 +15,7 @@ from toolkit.lorm import extract_conv, extract_linear, count_parameters
 from toolkit.metadata import add_model_hash_to_meta
 from toolkit.paths import KEYMAPS_ROOT
 from toolkit.saving import get_lora_keymap_from_model_keymap
+from optimum.quanto import QBytesTensor
 
 if TYPE_CHECKING:
     from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
@@ -27,7 +28,8 @@ Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']
 
 LINEAR_MODULES = [
     'Linear',
-    'LoRACompatibleLinear'
+    'LoRACompatibleLinear',
+    'QLinear'
     # 'GroupNorm',
 ]
 CONV_MODULES = [
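Why 'QLinear' needs to join LINEAR_MODULES: optimum.quanto's quantize() swaps each nn.Linear for a QLinear in place, so class-name checks against LINEAR_MODULES would silently skip quantized layers without the new entry. A minimal sketch of that swap (toy model, not code from this repo):

import torch.nn as nn
from optimum.quanto import quantize, qint8

# quantize() replaces nn.Linear children with QLinear in place,
# changing the class name that LINEAR_MODULES is matched against.
model = nn.Sequential(nn.Linear(16, 16))
quantize(model, weights=qint8)

print(model[0].__class__.__name__)  # QLinear
print(model[0].__class__.__name__ in ['Linear', 'LoRACompatibleLinear', 'QLinear'])  # True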
@@ -108,11 +110,16 @@ class ExtractableModuleMixin:
         if extract_mode == "existing":
             extract_mode = 'fixed'
             extract_mode_param = self.lora_dim
 
+        if isinstance(weight_to_extract, QBytesTensor):
+            weight_to_extract = weight_to_extract.dequantize()
+
+        weight_to_extract = weight_to_extract.clone().detach().float()
+
         if self.org_module[0].__class__.__name__ in CONV_MODULES:
             # do conv extraction
             down_weight, up_weight, new_dim, diff = extract_conv(
-                weight=weight_to_extract.clone().detach().float(),
+                weight=weight_to_extract,
                 mode=extract_mode,
                 mode_param=extract_mode_param,
                 device=device
@@ -121,7 +128,7 @@ class ExtractableModuleMixin:
         elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
             # do linear extraction
             down_weight, up_weight, new_dim, diff = extract_linear(
-                weight=weight_to_extract.clone().detach().float(),
+                weight=weight_to_extract,
                 mode=extract_mode,
                 mode_param=extract_mode_param,
                 device=device,
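Both extraction branches now receive a plain tensor: quanto-packed weights are dequantized first, and the clone().detach().float() copy is hoisted out of the two call sites so it runs once. A hedged sketch of why that order matters, assuming the usual quantize-then-freeze quanto workflow; the rank-8 SVD below stands in for whatever extract_linear does internally, and the names and shapes are illustrative, not from the repo:

import torch
import torch.nn as nn
from optimum.quanto import quantize, freeze, qint8, QBytesTensor

model = nn.Sequential(nn.Linear(64, 64))
quantize(model, weights=qint8)
freeze(model)  # after freezing, weights are stored quantized

w = model[0].weight
print(isinstance(w, QBytesTensor))  # True: packed int8 data plus scale, not a plain tensor

# SVD cannot run on the packed representation, so unpack to float first,
# then take a copy that is off the autograd graph (the hoisted lines above).
if isinstance(w, QBytesTensor):
    w = w.dequantize()
w = w.clone().detach().float()

# Rank-8 truncated SVD: the essence of a LoRA-style linear extraction.
U, S, Vh = torch.linalg.svd(w)
rank = 8
up = U[:, :rank] * S[:rank]  # (out_features, rank)
down = Vh[:rank, :]          # (rank, in_features)
print((up @ down).shape)     # torch.Size([64, 64]), a rank-8 approximation of w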
@@ -210,6 +217,11 @@ class ToolkitModuleMixin:
         network: Network = self.network_ref()
         if not network.is_active:
             return self.org_forward(x, *args, **kwargs)
 
+        orig_dtype = x.dtype
+
+        if x.dtype != self.lora_down.weight.dtype:
+            x = x.to(self.lora_down.weight.dtype)
+
         if network.lorm_train_mode == 'local':
             # we are going to predict input with both and do a loss on them
@@ -230,7 +242,9 @@ class ToolkitModuleMixin:
             return target_pred
 
         else:
-            return self.lora_up(self.lora_down(x))
+            x = self.lora_up(self.lora_down(x))
+            if x.dtype != orig_dtype:
+                x = x.to(orig_dtype)
 
     def forward(self: Module, x, *args, **kwargs):
         skip = False
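Taken together, the two forward-path hunks give the LoRM branch a dtype round trip: cast the input to the LoRA weights' dtype on the way in, remember orig_dtype, and cast the result back on the way out. A standalone sketch of the pattern, assuming a float32 adapter inside a model whose activations arrive in bfloat16 (names and shapes are illustrative):

import torch
import torch.nn as nn

# A float32 LoRA pair inside a (e.g. quantized) model that emits
# bfloat16 activations: the mismatch the new casts guard against.
lora_down = nn.Linear(32, 4, bias=False)
lora_up = nn.Linear(4, 32, bias=False)

x = torch.randn(1, 32, dtype=torch.bfloat16)
orig_dtype = x.dtype

# Without this cast, the matmul raises a dtype-mismatch RuntimeError.
if x.dtype != lora_down.weight.dtype:
    x = x.to(lora_down.weight.dtype)

x = lora_up(lora_down(x))

# Restore the caller's dtype so downstream layers are unaffected.
if x.dtype != orig_dtype:
    x = x.to(orig_dtype)

print(x.dtype)  # torch.bfloat16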