mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Reworked so everything is in classes for easy expansion. Single entry point for all config files now.
This commit is contained in:
76
jobs/process/BaseExtractProcess.py
Normal file
76
jobs/process/BaseExtractProcess.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from jobs import ExtractJob
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
|
||||
|
||||
class BaseExtractProcess(BaseProcess):
    """Base class for extract processes that save a state dict to disk.

    Child classes implement ``run`` (calling ``super().run()`` first so that
    ``output_path`` is resolved) and pass the resulting state dict to ``save``.
    """
    job: ExtractJob
    process_id: int
    config: OrderedDict
    output_folder: str
    # NOTE(review): nothing in this class assigns output_filename; it is only
    # read by get_output_path() when neither 'output_path' nor 'filename' is
    # configured -- confirm it is set elsewhere before that path is relied on.
    output_filename: str
    output_path: str

    def __init__(
            self,
            process_id: int,
            job: ExtractJob,
            config: OrderedDict
    ):
        """Create the process.

        :param process_id: id of this process, used in error messages and as
            the default filename suffix.
        :param job: the job this process belongs to.
        :param config: the config mapping read through ``get_conf``.
        """
        # the parent constructor stores process_id, job and config on self
        super().__init__(process_id, job, config)

    def run(self):
        """Resolve ``output_path``.

        Implement the real work in the child class and be sure to call
        ``super().run()`` first.
        """
        # here instead of init because child init needs to go first
        self.output_path = self.get_output_path()

    # you can override this in the child class if you want
    # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this
    def get_output_path(self, prefix=None, suffix=None):
        """Return the path the output file will be written to.

        Resolution order:

        1. ``output_path`` from the config, returned as-is.
        2. ``filename`` from the config, joined onto ``job.output_folder``.
        3. ``{prefix}{self.output_filename}{suffix}`` inside ``job.output_folder``.

        ``[name]`` in a configured ``output_path`` or ``filename`` is replaced
        with ``job.name``.

        :param prefix: text put before the generated filename (default: none).
        :param suffix: text put after the generated filename. When omitted,
            ``_<process_id>`` is used if there is more than one process.
        :return: the resolved output path.
        """
        config_output_path = self.get_conf('output_path', None)
        if config_output_path is not None:
            return config_output_path.replace('[name]', self.job.name)

        config_filename = self.get_conf('filename', None)
        if config_filename is not None:
            # build the output path from the output folder and filename,
            # honouring the same [name] placeholder as output_path
            config_filename = config_filename.replace('[name]', self.job.name)
            return os.path.join(self.job.output_folder, config_filename)

        # build our own
        if suffix is None:
            # we will just add the process id to the end of the filename if
            # there is more than one process and no other suffix was given
            # NOTE(review): this raises KeyError unless self.config carries a
            # 'process' entry -- confirm which config object is expected here.
            suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else ''

        if prefix is None:
            prefix = ''

        output_filename = f"{prefix}{self.output_filename}{suffix}"
        return os.path.join(self.job.output_folder, output_filename)

    def save(self, state_dict):
        """Save ``state_dict`` to ``self.output_path`` together with the metadata.

        :param state_dict: the mapping handed to ``save_file``.
        """
        # prepare meta
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # save; a bare filename has an empty dirname and os.makedirs('') raises,
        # so only create the directory when there is one
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # having issues with meta
        save_file(state_dict, self.output_path, save_meta)

        print(f"Saved to {self.output_path}")
|
||||
42
jobs/process/BaseProcess.py
Normal file
42
jobs/process/BaseProcess.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import copy
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from jobs import BaseJob
|
||||
|
||||
|
||||
class BaseProcess:
    """Base class for a single process of a job.

    Holds the process id, the owning job, the config mapping and a private
    copy of the job's meta that the process can extend.
    """
    meta: OrderedDict

    def __init__(
            self,
            process_id: int,
            job: 'BaseJob',
            config: OrderedDict
    ):
        """Create the process.

        :param process_id: id of this process, used in error messages.
        :param job: the owning job; its ``meta`` attribute is deep-copied.
        :param config: the mapping read through ``get_conf``.
        """
        self.process_id = process_id
        self.job = job
        self.config = config
        # deep copy so add_meta() does not modify the job's own meta object
        self.meta = copy.deepcopy(self.job.meta)

    def get_conf(self, key, default=None, required=False, as_type=None):
        """Read ``key`` from the config.

        :param key: config key to look up.
        :param default: value returned when the key is absent and not required.
        :param required: when true, a missing key raises ``ValueError``.
        :param as_type: optional callable used to convert the result.
        :return: the configured value or ``default``, converted with
            ``as_type`` when one is given. ``None`` is returned unconverted.
        :raises ValueError: if ``required`` is true and the key is missing.
        """
        if key in self.config:
            value = self.config[key]
        elif required:
            raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
        else:
            value = default

        # never cast None: int(None)/float(None) raise and str(None) would
        # silently produce the text 'None'
        if as_type is not None and value is not None:
            value = as_type(value)
        return value

    def run(self):
        """Run the process.

        Implement in the child class and be sure to call ``super().run()``
        first in case something is added here.
        """
        pass

    def add_meta(self, additional_meta: OrderedDict):
        """Merge ``additional_meta`` into this process's meta."""
        self.meta.update(additional_meta)
|
||||
|
||||
67
jobs/process/LoconExtractProcess.py
Normal file
67
jobs/process/LoconExtractProcess.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from collections import OrderedDict
|
||||
from toolkit.lycoris_utils import extract_diff
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
from .. import ExtractJob
|
||||
|
||||
# Per-mode defaults: the default 'linear' and 'conv' values for each
# extraction mode and the type those values are given as.
mode_dict = {
    'fixed': dict(linear=64, conv=32, type=int),
    'threshold': dict(linear=0, conv=0, type=float),
    'ratio': dict(linear=0.5, conv=0.5, type=float),
    'quantile': dict(linear=0.5, conv=0.5, type=float),
}
|
||||
|
||||
|
||||
class LoconExtractProcess(BaseExtractProcess):
    """Extract process that runs ``extract_diff`` on the job's models and saves the result."""

    def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
        """Read the extraction settings from the config.

        :param process_id: id of this process.
        :param job: the extract job this process belongs to.
        :param config: the config mapping read through ``get_conf``.
        :raises ValueError: if the configured ``mode`` is not a known mode.
        """
        super().__init__(process_id, job, config)
        self.mode = self.get_conf('mode', 'fixed')
        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
        self.sparsity = self.get_conf('sparsity', 0.98)
        self.disable_cp = self.get_conf('disable_cp', False)

        # set modes
        if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
            raise ValueError(f"Unknown mode: {self.mode}")

        mode_defaults = mode_dict[self.mode]
        # the type must be passed as as_type=...; get_conf's third positional
        # parameter is `required`, which would turn a missing key into an
        # error and skip the conversion
        self.linear_param = self.get_conf('linear', mode_defaults['linear'], as_type=mode_defaults['type'])
        self.conv_param = self.get_conf('conv', mode_defaults['conv'], as_type=mode_defaults['type'])

    def run(self):
        """Run ``extract_diff``, add the meta it returns and save the state dict."""
        # resolves self.output_path
        super().run()
        print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")

        state_dict, extract_diff_meta = extract_diff(
            self.job.base_model,
            self.job.extract_model,
            self.mode,
            self.linear_param,
            self.conv_param,
            self.job.device,
            self.use_sparse_bias,
            self.sparsity,
            not self.disable_cp
        )

        self.add_meta(extract_diff_meta)
        self.save(state_dict)

    def get_output_path(self, prefix=None, suffix=None):
        """Return the output path, defaulting the suffix to ``_<mode>_<linear>_<conv>``.

        :param prefix: forwarded unchanged to the parent implementation.
        :param suffix: filename suffix; built from the mode and its two
            parameters when omitted.
        :return: the path resolved by the parent implementation.
        """
        if suffix is None:
            suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
        return super().get_output_path(prefix, suffix)
|
||||
|
||||
3
jobs/process/__init__.py
Normal file
3
jobs/process/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
from .LoconExtractProcess import LoconExtractProcess
|
||||
from .BaseProcess import BaseProcess
|
||||
Reference in New Issue
Block a user