mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added extensions and an example extension that merges models
This commit is contained in:
@@ -60,7 +60,11 @@ class BaseJob:
|
||||
|
||||
# check if dict key is process type
|
||||
if process['type'] in process_dict:
|
||||
ProcessClass = getattr(module, process_dict[process['type']])
|
||||
if isinstance(process_dict[process['type']], str):
|
||||
ProcessClass = getattr(module, process_dict[process['type']])
|
||||
else:
|
||||
# it is the class
|
||||
ProcessClass = process_dict[process['type']]
|
||||
self.process.append(ProcessClass(i, self, process))
|
||||
else:
|
||||
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
|
||||
|
||||
21
jobs/ExtensionJob.py
Normal file
21
jobs/ExtensionJob.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from collections import OrderedDict
|
||||
from jobs import BaseJob
|
||||
from toolkit.extension import get_all_extensions_process_dict
|
||||
|
||||
|
||||
class ExtensionJob(BaseJob):
|
||||
|
||||
def __init__(self, config: OrderedDict):
|
||||
super().__init__(config)
|
||||
self.device = self.get_conf('device', 'cpu')
|
||||
self.process_dict = get_all_extensions_process_dict()
|
||||
self.load_processes(self.process_dict)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
print("")
|
||||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
|
||||
|
||||
for process in self.process:
|
||||
process.run()
|
||||
@@ -4,3 +4,4 @@ from .TrainJob import TrainJob
|
||||
from .MergeJob import MergeJob
|
||||
from .ModJob import ModJob
|
||||
from .GenerateJob import GenerateJob
|
||||
from .ExtensionJob import ExtensionJob
|
||||
|
||||
20
jobs/process/BaseExtensionProcess.py
Normal file
20
jobs/process/BaseExtensionProcess.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from collections import OrderedDict
|
||||
from typing import ForwardRef
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
|
||||
|
||||
class BaseExtensionProcess(BaseProcess):
|
||||
process_id: int
|
||||
config: OrderedDict
|
||||
progress_bar: ForwardRef('tqdm') = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_id: int,
|
||||
job,
|
||||
config: OrderedDict
|
||||
):
|
||||
super().__init__(process_id, job, config)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -11,3 +11,4 @@ from .TrainLoRAHack import TrainLoRAHack
|
||||
from .TrainSDRescaleProcess import TrainSDRescaleProcess
|
||||
from .ModRescaleLoraProcess import ModRescaleLoraProcess
|
||||
from .GenerateProcess import GenerateProcess
|
||||
from .BaseExtensionProcess import BaseExtensionProcess
|
||||
|
||||
Reference in New Issue
Block a user