Added support for traditional LoRA extraction using the LoCon script

This commit is contained in:
Jaret Burkett
2023-07-12 19:51:40 -06:00
parent 57f14e5ef2
commit 8d6edae9fd
9 changed files with 111 additions and 8 deletions

View File

@@ -8,6 +8,8 @@ from toolkit.metadata import get_meta_for_safetensors
from typing import ForwardRef
from toolkit.train_tools import get_torch_dtype
class BaseExtractProcess(BaseProcess):
process_id: int
@@ -26,6 +28,8 @@ class BaseExtractProcess(BaseProcess):
self.process_id = process_id
self.job = job
self.config = config
self.dtype = self.get_conf('dtype', self.job.dtype)
self.torch_dtype = get_torch_dtype(self.dtype)
def run(self):
# here instead of init because child init needs to go first
@@ -70,6 +74,11 @@ class BaseExtractProcess(BaseProcess):
# save
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(self.torch_dtype)
state_dict[key] = v
# having issues with meta
save_file(state_dict, self.output_path, save_meta)

View File

@@ -21,13 +21,13 @@ class BaseProcess:
def get_conf(self, key, default=None, required=False, as_type=None):
    """Look up *key* in this process's config dict.

    Args:
        key: Config key to read from ``self.config``.
        default: Value returned when the key is absent (ignored if ``required``).
        required: When True and the key is missing, raise instead of defaulting.
        as_type: Optional callable used to coerce the value (e.g. ``int``).

    Returns:
        The configured value (coerced via ``as_type`` when given and non-None),
        or ``default`` otherwise.

    Raises:
        ValueError: If ``required`` is True and the key is not present.
    """
    if key in self.config:
        value = self.config[key]
        # Only coerce real values: as_type(None) would raise for most types,
        # and an explicit None in the config should pass through unchanged.
        if as_type is not None and value is not None:
            value = as_type(value)
        return value
    elif required:
        raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
    else:
        # Same None-guard for the fallback path.
        if as_type is not None and default is not None:
            return as_type(default)
        return default

View File

@@ -37,8 +37,8 @@ class ExtractLoconProcess(BaseExtractProcess):
# set modes
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
raise ValueError(f"Unknown mode: {self.mode}")
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])
def run(self):
super().run()

View File

@@ -0,0 +1,52 @@
from collections import OrderedDict

from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess

# Per-mode defaults: each entry maps a mode name to its default rank ('dim')
# and the type used to coerce the configured value.
mode_dict = {
    'fixed': {
        'dim': 4,
        'type': int
    }
}

CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6


class ExtractLoraProcess(BaseExtractProcess):
    """Extract a traditional (linear-only) LoRA from the difference between
    the job's base model and its fine-tuned model.

    Reads 'mode', 'dim', 'use_sparse_bias' and 'sparsity' from the process
    config; only the 'fixed' rank mode is supported.
    """

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.mode = self.get_conf('mode', 'fixed')
        # Reject anything other than the fixed-rank mode up front.
        if self.mode not in ['fixed']:
            raise ValueError(f"Unknown mode: {self.mode}")
        defaults = mode_dict[self.mode]
        self.dim = self.get_conf('dim', defaults['dim'], as_type=defaults['type'])
        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
        self.sparsity = self.get_conf('sparsity', 0.98)

    def run(self):
        super().run()
        print(f"Running process: {self.mode}, dim: {self.dim}")
        # linear_only=True restricts extraction to linear layers (no conv),
        # which is what makes this a "traditional" LoRA rather than a LoCon.
        state_dict, extract_diff_meta = extract_diff(
            self.job.base_model,
            self.job.extract_model,
            self.mode,
            self.dim,
            0,
            self.job.device,
            self.use_sparse_bias,
            self.sparsity,
            small_conv=False,
            linear_only=True,
        )
        self.add_meta(extract_diff_meta)
        self.save(state_dict)

    def get_output_path(self, prefix=None, suffix=None):
        # Default the filename suffix to the extraction rank, e.g. "_4".
        suffix = f"_{self.dim}" if suffix is None else suffix
        return super().get_output_path(prefix, suffix)

View File

@@ -1,5 +1,6 @@
from .BaseExtractProcess import BaseExtractProcess
from .ExtractLoconProcess import ExtractLoconProcess
from .ExtractLoraProcess import ExtractLoraProcess
from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess
from .TrainFineTuneProcess import TrainFineTuneProcess