Reworked so everything is in classes for easy expansion. Single entry point for all config files now.

This commit is contained in:
Jaret Burkett
2023-07-08 09:51:42 -06:00
parent 27df03a486
commit 37354b006e
16 changed files with 424 additions and 189 deletions

View File

@@ -0,0 +1,76 @@
import os
from collections import OrderedDict
from safetensors.torch import save_file
from jobs import ExtractJob
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors
class BaseExtractProcess(BaseProcess):
job: ExtractJob
process_id: int
config: OrderedDict
output_folder: str
output_filename: str
output_path: str
def __init__(
self,
process_id: int,
job: ExtractJob,
config: OrderedDict
):
super().__init__(process_id, job, config)
self.process_id = process_id
self.job = job
self.config = config
def run(self):
# here instead of init because child init needs to go first
self.output_path = self.get_output_path()
# implement in child class
# be sure to call super().run() first
pass
# you can override this in the child class if you want
# call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this
def get_output_path(self, prefix=None, suffix=None):
config_output_path = self.get_conf('output_path', None)
config_filename = self.get_conf('filename', None)
# replace [name] with name
if config_output_path is not None:
config_output_path = config_output_path.replace('[name]', self.job.name)
return config_output_path
if config_output_path is None and config_filename is not None:
# build the output path from the output folder and filename
return os.path.join(self.job.output_folder, config_filename)
# build our own
if suffix is None:
# we will just add process it to the end of the filename if there is more than one process
# and no other suffix was given
suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else ''
if prefix is None:
prefix = ''
output_filename = f"{prefix}{self.output_filename}{suffix}"
return os.path.join(self.job.output_folder, output_filename)
def save(self, state_dict):
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
# save
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
# having issues with meta
save_file(state_dict, self.output_path, save_meta)
print(f"Saved to {self.output_path}")

View File

@@ -0,0 +1,42 @@
import copy
import json
from collections import OrderedDict
from jobs import BaseJob
class BaseProcess:
meta: OrderedDict
def __init__(
self,
process_id: int,
job: BaseJob,
config: OrderedDict
):
self.process_id = process_id
self.job = job
self.config = config
self.meta = copy.deepcopy(self.job.meta)
def get_conf(self, key, default=None, required=False, as_type=None):
if key in self.config:
value = self.config[key]
if as_type is not None:
value = as_type(value)
return value
elif required:
raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
else:
if as_type is not None:
return as_type(default)
return default
def run(self):
# implement in child class
# be sure to call super().run() first incase something is added here
pass
def add_meta(self, additional_meta: OrderedDict):
self.meta.update(additional_meta)

View File

@@ -0,0 +1,67 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
from .. import ExtractJob
mode_dict = {
'fixed': {
'linear': 64,
'conv': 32,
'type': int
},
'threshold': {
'linear': 0,
'conv': 0,
'type': float
},
'ratio': {
'linear': 0.5,
'conv': 0.5,
'type': float
},
'quantile': {
'linear': 0.5,
'conv': 0.5,
'type': float
}
}
class LoconExtractProcess(BaseExtractProcess):
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
self.sparsity = self.get_conf('sparsity', 0.98)
self.disable_cp = self.get_conf('disable_cp', False)
# set modes
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
raise ValueError(f"Unknown mode: {self.mode}")
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])
def run(self):
super().run()
print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
state_dict, extract_diff_meta = extract_diff(
self.job.base_model,
self.job.extract_model,
self.mode,
self.linear_param,
self.conv_param,
self.job.device,
self.use_sparse_bias,
self.sparsity,
not self.disable_cp
)
self.add_meta(extract_diff_meta)
self.save(state_dict)
def get_output_path(self, prefix=None, suffix=None):
if suffix is None:
suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
return super().get_output_path(prefix, suffix)

3
jobs/process/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .BaseExtractProcess import BaseExtractProcess
from .LoconExtractProcess import LoconExtractProcess
from .BaseProcess import BaseProcess