mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Setup base for training jobs. Added sd-scripts as a submodule
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "repositories/sd-scripts"]
|
||||||
|
path = repositories/sd-scripts
|
||||||
|
url = https://github.com/kohya-ss/sd-scripts.git
|
||||||
32
config/examples/train.example.json
Normal file
32
config/examples/train.example.json
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{
|
||||||
|
"job": "train",
|
||||||
|
"config": {
|
||||||
|
"name": "name_of_your_model",
|
||||||
|
"base_model": "/path/to/base/model",
|
||||||
|
"training_folder": "/path/to/output/folder",
|
||||||
|
"is_v2": false,
|
||||||
|
"device": "cpu",
|
||||||
|
"process": [
|
||||||
|
{
|
||||||
|
"type": "fine_tune"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"meta": {
|
||||||
|
"name": "[name]",
|
||||||
|
"description": "A short description of your model",
|
||||||
|
"trigger_words": [
|
||||||
|
"put",
|
||||||
|
"trigger",
|
||||||
|
"words",
|
||||||
|
"here"
|
||||||
|
],
|
||||||
|
"version": "0.1",
|
||||||
|
"creator": {
|
||||||
|
"name": "Your Name",
|
||||||
|
"email": "your@email.com",
|
||||||
|
"website": "https://yourwebsite.com"
|
||||||
|
},
|
||||||
|
"any": "All meta data above is arbitrary, it can be whatever you want."
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from jobs.process import BaseProcess
|
||||||
|
|
||||||
|
|
||||||
class BaseJob:
|
class BaseJob:
|
||||||
@@ -6,6 +9,7 @@ class BaseJob:
|
|||||||
job: str
|
job: str
|
||||||
name: str
|
name: str
|
||||||
meta: OrderedDict
|
meta: OrderedDict
|
||||||
|
process: List[BaseProcess]
|
||||||
|
|
||||||
def __init__(self, config: OrderedDict):
|
def __init__(self, config: OrderedDict):
|
||||||
if not config:
|
if not config:
|
||||||
@@ -37,6 +41,25 @@ class BaseJob:
|
|||||||
# be sure to call super().run() first
|
# be sure to call super().run() first
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def load_processes(self, process_dict: dict):
    """Instantiate every process listed under ``config.process``.

    Only call this from job types that actually have processes.

    Args:
        process_dict: Maps a process ``type`` string to the class (or any
            callable) that builds it. Each one is called as
            ``factory(index, job, process_config)``.

    Raises:
        ValueError: If ``config.process`` is missing, is not a non-empty
            list, or if an entry is not a dict, has no ``type`` key, or
            names a type that is not in ``process_dict``.
    """
    if 'process' not in self.config:
        raise ValueError('config file is invalid. Missing "config.process" key')

    process_configs = self.config['process']
    # Strings and dicts also have a length, so check the container type
    # explicitly instead of relying on len() alone.
    if not isinstance(process_configs, (list, tuple)) or len(process_configs) == 0:
        raise ValueError('config file is invalid. "config.process" must be a list of processes')

    # add the processes
    self.process = []
    for i, process in enumerate(process_configs):
        if not isinstance(process, dict) or 'type' not in process:
            raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')

        process_type = process['type']
        # the type must be one of the keys registered in process_dict
        if process_type not in process_dict:
            raise ValueError(f'config file is invalid. Unknown process type: {process_type}')

        self.process.append(process_dict[process_type](i, self, process))
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
# if you implement this in child clas,
|
# if you implement this in child clas,
|
||||||
# be sure to call super().cleanup() LAST
|
# be sure to call super().cleanup() LAST
|
||||||
|
|||||||
@@ -5,6 +5,12 @@ from typing import List
|
|||||||
|
|
||||||
from jobs.process import BaseExtractProcess
|
from jobs.process import BaseExtractProcess
|
||||||
|
|
||||||
|
from jobs.process import ExtractLoconProcess
|
||||||
|
|
||||||
|
# Registry handed to load_processes(): maps the "type" value of a
# config.process entry to the process class that gets instantiated for it.
process_dict = {
    'locon': ExtractLoconProcess,
}
|
||||||
|
|
||||||
|
|
||||||
class ExtractJob(BaseJob):
|
class ExtractJob(BaseJob):
|
||||||
process: List[BaseExtractProcess]
|
process: List[BaseExtractProcess]
|
||||||
@@ -19,21 +25,8 @@ class ExtractJob(BaseJob):
|
|||||||
self.is_v2 = self.get_conf('is_v2', False)
|
self.is_v2 = self.get_conf('is_v2', False)
|
||||||
self.device = self.get_conf('device', 'cpu')
|
self.device = self.get_conf('device', 'cpu')
|
||||||
|
|
||||||
if 'process' not in self.config:
|
# loads the processes from the config
|
||||||
raise ValueError('config file is invalid. Missing "config.process" key')
|
self.load_processes(process_dict)
|
||||||
if len(self.config['process']) == 0:
|
|
||||||
raise ValueError('config file is invalid. "config.process" must be a list of processes')
|
|
||||||
|
|
||||||
# add the processes
|
|
||||||
self.process = []
|
|
||||||
for i, process in enumerate(self.config['process']):
|
|
||||||
if 'type' not in process:
|
|
||||||
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
|
|
||||||
if process['type'] == 'locon':
|
|
||||||
from jobs.process import LoconExtractProcess
|
|
||||||
self.process.append(LoconExtractProcess(i, self, process))
|
|
||||||
else:
|
|
||||||
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run()
|
super().run()
|
||||||
@@ -50,4 +43,3 @@ class ExtractJob(BaseJob):
|
|||||||
|
|
||||||
for process in self.process:
|
for process in self.process:
|
||||||
process.run()
|
process.run()
|
||||||
|
|
||||||
|
|||||||
38
jobs/TrainJob.py
Normal file
38
jobs/TrainJob.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
||||||
|
from .BaseJob import BaseJob
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
||||||
|
|
||||||
|
# Registry handed to load_processes(): maps the "type" value of a
# config.process entry to the process class that gets instantiated for it.
# The key must match the type used in train configs ("fine_tune").
process_dict = {
    'fine_tune': TrainFineTuneProcess,
}
|
||||||
|
|
||||||
|
|
||||||
|
class TrainJob(BaseJob):
    """Job that loads a base model checkpoint and then runs its processes.

    Settings are read with ``get_conf``: ``base_model`` and
    ``training_folder`` are requested as required, ``is_v2`` falls back to
    ``False`` and ``device`` falls back to ``'cpu'``.
    """

    # NOTE(review): annotated with the extract base class; a train-specific
    # base type is probably what is meant here -- confirm.
    process: List[BaseExtractProcess]

    def __init__(self, config: OrderedDict):
        """Read the job settings and build the configured processes.

        Args:
            config: The job configuration, forwarded to ``BaseJob``.
        """
        super().__init__(config)

        # where the base model lives; the model itself is loaded in run()
        self.base_model_path = self.get_conf('base_model', required=True)
        self.base_model = None

        self.training_folder = self.get_conf('training_folder', required=True)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')

        # build self.process from the "process" entries of the config
        self.load_processes(process_dict)

    def run(self):
        """Load the base model, then run every process in order."""
        super().run()

        # load models
        print("Loading base model for training")
        print(f" - Loading base model: {self.base_model_path}")
        self.base_model = load_models_from_stable_diffusion_checkpoint(
            self.is_v2,
            self.base_model_path,
        )

        process_count = len(self.process)
        plural_suffix = '' if process_count == 1 else 'es'
        print("")
        print(f"Running {process_count} process{plural_suffix}")

        for current_process in self.process:
            current_process.run()
|
||||||
@@ -1,2 +1,3 @@
|
|||||||
from .BaseJob import BaseJob
|
from .BaseJob import BaseJob
|
||||||
from .ExtractJob import ExtractJob
|
from .ExtractJob import ExtractJob
|
||||||
|
from .TrainJob import TrainJob
|
||||||
|
|||||||
25
jobs/process/BaseTrainProcess.py
Normal file
25
jobs/process/BaseTrainProcess.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
from jobs import TrainJob
|
||||||
|
from jobs.process.BaseProcess import BaseProcess
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTrainProcess(BaseProcess):
    """Base class for processes that belong to a ``TrainJob``."""

    job: TrainJob
    process_id: int
    config: OrderedDict

    def __init__(self, process_id: int, job: TrainJob, config: OrderedDict):
        """Initialise the parent process and keep the arguments on the instance.

        Args:
            process_id: Identifier passed through to ``BaseProcess``.
            job: The job this process was created for.
            config: Configuration for this process.
        """
        super().__init__(process_id, job, config)
        # stored on the instance after the parent initializer has run
        self.process_id, self.job, self.config = process_id, job, config

    def run(self):
        """Hook for subclasses; does nothing here.

        Subclasses overriding this should call ``super().run()`` first.
        """
        return None
|
||||||
@@ -27,7 +27,7 @@ mode_dict = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LoconExtractProcess(BaseExtractProcess):
|
class ExtractLoconProcess(BaseExtractProcess):
|
||||||
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
|
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
|
||||||
super().__init__(process_id, job, config)
|
super().__init__(process_id, job, config)
|
||||||
self.mode = self.get_conf('mode', 'fixed')
|
self.mode = self.get_conf('mode', 'fixed')
|
||||||
13
jobs/process/TrainFineTuneProcess.py
Normal file
13
jobs/process/TrainFineTuneProcess.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
from jobs import TrainJob
|
||||||
|
from jobs.process import BaseTrainProcess
|
||||||
|
|
||||||
|
|
||||||
|
class TrainFineTuneProcess(BaseTrainProcess):
    """Fine-tune training process.

    Currently a stub: construction only delegates to ``BaseTrainProcess``
    and ``run`` performs no work.
    """

    def __init__(self, process_id: int, job: TrainJob, config: OrderedDict):
        """Forward all arguments unchanged to ``BaseTrainProcess``."""
        super().__init__(process_id, job, config)

    def run(self):
        """Placeholder; no fine-tuning logic is implemented yet."""
        return None
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
from .BaseExtractProcess import BaseExtractProcess
|
from .BaseExtractProcess import BaseExtractProcess
|
||||||
from .LoconExtractProcess import LoconExtractProcess
|
from .ExtractLoconProcess import ExtractLoconProcess
|
||||||
from .BaseProcess import BaseProcess
|
from .BaseProcess import BaseProcess
|
||||||
|
from .BaseTrainProcess import BaseTrainProcess
|
||||||
|
from .TrainFineTuneProcess import TrainFineTuneProcess
|
||||||
|
|||||||
1
repositories/sd-scripts
Submodule
1
repositories/sd-scripts
Submodule
Submodule repositories/sd-scripts added at 0cfcb5a49c
@@ -11,5 +11,8 @@ def get_job(config_path) -> BaseJob:
|
|||||||
if job == 'extract':
|
if job == 'extract':
|
||||||
from jobs import ExtractJob
|
from jobs import ExtractJob
|
||||||
return ExtractJob(config)
|
return ExtractJob(config)
|
||||||
|
elif job == 'train':
|
||||||
|
from jobs import TrainJob
|
||||||
|
return TrainJob(config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown job type {job}')
|
raise ValueError(f'Unknown job type {job}')
|
||||||
|
|||||||
Reference in New Issue
Block a user