mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Setup base for training jobs. Added sd-scripts as a submodule
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "repositories/sd-scripts"]
|
||||
path = repositories/sd-scripts
|
||||
url = https://github.com/kohya-ss/sd-scripts.git
|
||||
32
config/examples/train.example.json
Normal file
32
config/examples/train.example.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"job": "train",
|
||||
"config": {
|
||||
"name": "name_of_your_model",
|
||||
"base_model": "/path/to/base/model",
|
||||
"training_folder": "/path/to/output/folder",
|
||||
"is_v2": false,
|
||||
"device": "cpu",
|
||||
"process": [
|
||||
{
|
||||
"type": "fine_tune"
|
||||
}
|
||||
]
|
||||
},
|
||||
"meta": {
|
||||
"name": "[name]",
|
||||
"description": "A short description of your model",
|
||||
"trigger_words": [
|
||||
"put",
|
||||
"trigger",
|
||||
"words",
|
||||
"here"
|
||||
],
|
||||
"version": "0.1",
|
||||
"creator": {
|
||||
"name": "Your Name",
|
||||
"email": "your@email.com",
|
||||
"website": "https://yourwebsite.com"
|
||||
},
|
||||
"any": "All meta data above is arbitrary, it can be whatever you want."
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,7 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
from jobs.process import BaseProcess
|
||||
|
||||
|
||||
class BaseJob:
|
||||
@@ -6,6 +9,7 @@ class BaseJob:
|
||||
job: str
|
||||
name: str
|
||||
meta: OrderedDict
|
||||
process: List[BaseProcess]
|
||||
|
||||
def __init__(self, config: OrderedDict):
|
||||
if not config:
|
||||
@@ -37,6 +41,25 @@ class BaseJob:
|
||||
# be sure to call super().run() first
|
||||
pass
|
||||
|
||||
def load_processes(self, process_dict: dict):
|
||||
# only call if you have processes in this job type
|
||||
if 'process' not in self.config:
|
||||
raise ValueError('config file is invalid. Missing "config.process" key')
|
||||
if len(self.config['process']) == 0:
|
||||
raise ValueError('config file is invalid. "config.process" must be a list of processes')
|
||||
|
||||
# add the processes
|
||||
self.process = []
|
||||
for i, process in enumerate(self.config['process']):
|
||||
if 'type' not in process:
|
||||
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
|
||||
|
||||
# check if dict key is process type
|
||||
if process['type'] in process_dict:
|
||||
self.process.append(process_dict[process['type']](i, self, process))
|
||||
else:
|
||||
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
|
||||
|
||||
def cleanup(self):
|
||||
# if you implement this in child clas,
|
||||
# be sure to call super().cleanup() LAST
|
||||
|
||||
@@ -5,6 +5,12 @@ from typing import List
|
||||
|
||||
from jobs.process import BaseExtractProcess
|
||||
|
||||
from jobs.process import ExtractLoconProcess
|
||||
|
||||
process_dict = {
|
||||
'locon': ExtractLoconProcess,
|
||||
}
|
||||
|
||||
|
||||
class ExtractJob(BaseJob):
|
||||
process: List[BaseExtractProcess]
|
||||
@@ -19,21 +25,8 @@ class ExtractJob(BaseJob):
|
||||
self.is_v2 = self.get_conf('is_v2', False)
|
||||
self.device = self.get_conf('device', 'cpu')
|
||||
|
||||
if 'process' not in self.config:
|
||||
raise ValueError('config file is invalid. Missing "config.process" key')
|
||||
if len(self.config['process']) == 0:
|
||||
raise ValueError('config file is invalid. "config.process" must be a list of processes')
|
||||
|
||||
# add the processes
|
||||
self.process = []
|
||||
for i, process in enumerate(self.config['process']):
|
||||
if 'type' not in process:
|
||||
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
|
||||
if process['type'] == 'locon':
|
||||
from jobs.process import LoconExtractProcess
|
||||
self.process.append(LoconExtractProcess(i, self, process))
|
||||
else:
|
||||
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
|
||||
# loads the processes from the config
|
||||
self.load_processes(process_dict)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -50,4 +43,3 @@ class ExtractJob(BaseJob):
|
||||
|
||||
for process in self.process:
|
||||
process.run()
|
||||
|
||||
|
||||
38
jobs/TrainJob.py
Normal file
38
jobs/TrainJob.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
||||
from .BaseJob import BaseJob
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
||||
|
||||
process_dict = {
|
||||
'fine_tine': TrainFineTuneProcess
|
||||
}
|
||||
|
||||
|
||||
class TrainJob(BaseJob):
|
||||
process: List[BaseExtractProcess]
|
||||
|
||||
def __init__(self, config: OrderedDict):
|
||||
super().__init__(config)
|
||||
self.base_model_path = self.get_conf('base_model', required=True)
|
||||
self.base_model = None
|
||||
self.training_folder = self.get_conf('training_folder', required=True)
|
||||
self.is_v2 = self.get_conf('is_v2', False)
|
||||
self.device = self.get_conf('device', 'cpu')
|
||||
|
||||
# loads the processes from the config
|
||||
self.load_processes(process_dict)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
# load models
|
||||
print(f"Loading base model for training")
|
||||
print(f" - Loading base model: {self.base_model_path}")
|
||||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
|
||||
|
||||
print("")
|
||||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
|
||||
|
||||
for process in self.process:
|
||||
process.run()
|
||||
@@ -1,2 +1,3 @@
|
||||
from .BaseJob import BaseJob
|
||||
from .ExtractJob import ExtractJob
|
||||
from .TrainJob import TrainJob
|
||||
|
||||
25
jobs/process/BaseTrainProcess.py
Normal file
25
jobs/process/BaseTrainProcess.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from collections import OrderedDict
|
||||
from jobs import TrainJob
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
|
||||
|
||||
class BaseTrainProcess(BaseProcess):
|
||||
job: TrainJob
|
||||
process_id: int
|
||||
config: OrderedDict
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_id: int,
|
||||
job: TrainJob,
|
||||
config: OrderedDict
|
||||
):
|
||||
super().__init__(process_id, job, config)
|
||||
self.process_id = process_id
|
||||
self.job = job
|
||||
self.config = config
|
||||
|
||||
def run(self):
|
||||
# implement in child class
|
||||
# be sure to call super().run() first
|
||||
pass
|
||||
@@ -27,7 +27,7 @@ mode_dict = {
|
||||
}
|
||||
|
||||
|
||||
class LoconExtractProcess(BaseExtractProcess):
|
||||
class ExtractLoconProcess(BaseExtractProcess):
|
||||
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.mode = self.get_conf('mode', 'fixed')
|
||||
13
jobs/process/TrainFineTuneProcess.py
Normal file
13
jobs/process/TrainFineTuneProcess.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from collections import OrderedDict
|
||||
from jobs import TrainJob
|
||||
from jobs.process import BaseTrainProcess
|
||||
|
||||
|
||||
class TrainFineTuneProcess(BaseTrainProcess):
|
||||
def __init__(self,process_id: int, job: TrainJob, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
|
||||
def run(self):
|
||||
# implement in child class
|
||||
# be sure to call super().run() first
|
||||
pass
|
||||
@@ -1,3 +1,5 @@
|
||||
from .BaseExtractProcess import BaseExtractProcess
|
||||
from .LoconExtractProcess import LoconExtractProcess
|
||||
from .ExtractLoconProcess import ExtractLoconProcess
|
||||
from .BaseProcess import BaseProcess
|
||||
from .BaseTrainProcess import BaseTrainProcess
|
||||
from .TrainFineTuneProcess import TrainFineTuneProcess
|
||||
|
||||
1
repositories/sd-scripts
Submodule
1
repositories/sd-scripts
Submodule
Submodule repositories/sd-scripts added at 0cfcb5a49c
@@ -11,5 +11,8 @@ def get_job(config_path) -> BaseJob:
|
||||
if job == 'extract':
|
||||
from jobs import ExtractJob
|
||||
return ExtractJob(config)
|
||||
elif job == 'train':
|
||||
from jobs import TrainJob
|
||||
return TrainJob(config)
|
||||
else:
|
||||
raise ValueError(f'Unknown job type {job}')
|
||||
|
||||
Reference in New Issue
Block a user