Added extensions and an example extension that merges models

This commit is contained in:
Jaret Burkett
2023-08-04 09:37:24 -06:00
parent b865ac8b24
commit 7e4e660663
14 changed files with 366 additions and 24 deletions

View File

@@ -60,7 +60,11 @@ class BaseJob:
# check if dict key is process type
if process['type'] in process_dict:
ProcessClass = getattr(module, process_dict[process['type']])
if isinstance(process_dict[process['type']], str):
ProcessClass = getattr(module, process_dict[process['type']])
else:
# it is the class
ProcessClass = process_dict[process['type']]
self.process.append(ProcessClass(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')

21
jobs/ExtensionJob.py Normal file
View File

@@ -0,0 +1,21 @@
from collections import OrderedDict
from jobs import BaseJob
from toolkit.extension import get_all_extensions_process_dict
class ExtensionJob(BaseJob):
def __init__(self, config: OrderedDict):
super().__init__(config)
self.device = self.get_conf('device', 'cpu')
self.process_dict = get_all_extensions_process_dict()
self.load_processes(self.process_dict)
def run(self):
super().run()
print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
for process in self.process:
process.run()

View File

@@ -4,3 +4,4 @@ from .TrainJob import TrainJob
from .MergeJob import MergeJob
from .ModJob import ModJob
from .GenerateJob import GenerateJob
from .ExtensionJob import ExtensionJob

View File

@@ -0,0 +1,20 @@
from collections import OrderedDict
from typing import ForwardRef
from jobs.process.BaseProcess import BaseProcess
class BaseExtensionProcess(BaseProcess):
process_id: int
config: OrderedDict
progress_bar: ForwardRef('tqdm') = None
def __init__(
self,
process_id: int,
job,
config: OrderedDict
):
super().__init__(process_id, job, config)
def run(self):
super().run()

View File

@@ -11,3 +11,4 @@ from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess
from .BaseExtensionProcess import BaseExtensionProcess