Files
ai-toolkit/jobs/process/ExtractLoraProcess.py

53 lines
1.5 KiB
Python

from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
mode_dict = {
'fixed': {
'dim': 4,
'type': int
}
}
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
class ExtractLoraProcess(BaseExtractProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')
# set modes
if self.mode not in ['fixed']:
raise ValueError(f"Unknown mode: {self.mode}")
self.dim = self.get_conf('dim', mode_dict[self.mode]['dim'], as_type=mode_dict[self.mode]['type'])
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
self.sparsity = self.get_conf('sparsity', 0.98)
def run(self):
super().run()
print(f"Running process: {self.mode}, dim: {self.dim}")
state_dict, extract_diff_meta = extract_diff(
self.job.base_model,
self.job.extract_model,
self.mode,
self.dim,
0,
self.job.device,
self.use_sparse_bias,
self.sparsity,
small_conv=False,
linear_only=True,
)
self.add_meta(extract_diff_meta)
self.save(state_dict)
def get_output_path(self, prefix=None, suffix=None):
if suffix is None:
suffix = f"_{self.dim}"
return super().get_output_path(prefix, suffix)