mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 01:39:20 +00:00
Added support for traditional LoRa extract using LoCon script
This commit is contained in:
52
jobs/process/ExtractLoraProcess.py
Normal file
52
jobs/process/ExtractLoraProcess.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from collections import OrderedDict
|
||||
from toolkit.lycoris_utils import extract_diff
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
|
||||
mode_dict = {
|
||||
'fixed': {
|
||||
'dim': 4,
|
||||
'type': int
|
||||
}
|
||||
}
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
MIN_DIFF = 1e-6
|
||||
|
||||
|
||||
class ExtractLoraProcess(BaseExtractProcess):
|
||||
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.mode = self.get_conf('mode', 'fixed')
|
||||
|
||||
# set modes
|
||||
if self.mode not in ['fixed']:
|
||||
raise ValueError(f"Unknown mode: {self.mode}")
|
||||
self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type'])
|
||||
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
|
||||
self.sparsity = self.get_conf('sparsity', 0.98)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
print(f"Running process: {self.mode}, dim: {self.dim}")
|
||||
|
||||
state_dict, extract_diff_meta = extract_diff(
|
||||
self.job.base_model,
|
||||
self.job.extract_model,
|
||||
self.mode,
|
||||
self.dim,
|
||||
0,
|
||||
self.job.device,
|
||||
self.use_sparse_bias,
|
||||
self.sparsity,
|
||||
small_conv=False,
|
||||
linear_only=True,
|
||||
)
|
||||
|
||||
self.add_meta(extract_diff_meta)
|
||||
self.save(state_dict)
|
||||
|
||||
def get_output_path(self, prefix=None, suffix=None):
|
||||
if suffix is None:
|
||||
suffix = f"_{self.dim}"
|
||||
return super().get_output_path(prefix, suffix)
|
||||
Reference in New Issue
Block a user