Set up base for training jobs. Added sd-scripts as a submodule

This commit is contained in:
Jaret Burkett
2023-07-08 13:50:59 -06:00
parent 37354b006e
commit 47d094e528
13 changed files with 151 additions and 18 deletions

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "repositories/sd-scripts"]
path = repositories/sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git

View File

@@ -0,0 +1,32 @@
{
"job": "train",
"config": {
"name": "name_of_your_model",
"base_model": "/path/to/base/model",
"training_folder": "/path/to/output/folder",
"is_v2": false,
"device": "cpu",
"process": [
{
"type": "fine_tune"
}
]
},
"meta": {
"name": "[name]",
"description": "A short description of your model",
"trigger_words": [
"put",
"trigger",
"words",
"here"
],
"version": "0.1",
"creator": {
"name": "Your Name",
"email": "your@email.com",
"website": "https://yourwebsite.com"
},
"any": "All meta data above is arbitrary, it can be whatever you want."
}
}

View File

@@ -1,4 +1,7 @@
from collections import OrderedDict
from typing import List
from jobs.process import BaseProcess
class BaseJob:
@@ -6,6 +9,7 @@ class BaseJob:
job: str
name: str
meta: OrderedDict
process: List[BaseProcess]
def __init__(self, config: OrderedDict):
if not config:
@@ -37,6 +41,25 @@ class BaseJob:
# be sure to call super().run() first
pass
def load_processes(self, process_dict: dict):
    """Instantiate every process listed under ``config.process``.

    Only call this from job types that actually use processes.

    Args:
        process_dict: Maps the ``type`` string of a process entry to the
            callable that builds it. Each one is invoked as
            ``factory(index, job, process_config)``.

    Raises:
        ValueError: If ``config.process`` is missing, is not a non-empty
            list, or an entry has no ``type`` key or names an unknown type.
    """
    if 'process' not in self.config:
        raise ValueError('config file is invalid. Missing "config.process" key')

    process_configs = self.config['process']
    # The error message promises a list, so enforce it: a dict or string
    # would otherwise be iterated key-by-key / char-by-char, and None would
    # crash in len() with an unrelated TypeError.
    if not isinstance(process_configs, (list, tuple)) or not process_configs:
        raise ValueError('config file is invalid. "config.process" must be a list of processes')

    # add the processes
    self.process = []
    for i, process in enumerate(process_configs):
        if 'type' not in process:
            raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
        process_type = process['type']
        # the process type must be one of the keys registered by the job
        if process_type not in process_dict:
            raise ValueError(f'config file is invalid. Unknown process type: {process_type}')
        self.process.append(process_dict[process_type](i, self, process))
def cleanup(self):
# if you implement this in child class,
# be sure to call super().cleanup() LAST

View File

@@ -5,6 +5,12 @@ from typing import List
from jobs.process import BaseExtractProcess
from jobs.process import ExtractLoconProcess
# Registry handed to load_processes(): config "type" value -> process class.
process_dict = dict(locon=ExtractLoconProcess)
class ExtractJob(BaseJob):
process: List[BaseExtractProcess]
@@ -19,21 +25,8 @@ class ExtractJob(BaseJob):
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')
if 'process' not in self.config:
raise ValueError('config file is invalid. Missing "config.process" key')
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')
# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
if 'type' not in process:
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
if process['type'] == 'locon':
from jobs.process import LoconExtractProcess
self.process.append(LoconExtractProcess(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
# loads the processes from the config
self.load_processes(process_dict)
def run(self):
super().run()
@@ -50,4 +43,3 @@ class ExtractJob(BaseJob):
for process in self.process:
process.run()

38
jobs/TrainJob.py Normal file
View File

@@ -0,0 +1,38 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
# Registry handed to load_processes(): config "type" value -> process class.
# The key must match the "type" used in train configs ("fine_tune").
process_dict = {
    'fine_tune': TrainFineTuneProcess,
}
class TrainJob(BaseJob):
    """Job that loads a base model checkpoint and runs its training processes.

    Settings are read from the job config via ``get_conf``:
    ``base_model`` and ``training_folder`` are requested with
    ``required=True``; ``is_v2`` falls back to ``False`` and ``device``
    to ``'cpu'``.
    """

    # Only train processes are registered in process_dict for this job.
    # Widen this to a shared train base class once more types are added.
    process: List[TrainFineTuneProcess]

    def __init__(self, config: OrderedDict):
        """Read the training settings and build the configured processes.

        Args:
            config: The parsed job configuration.
        """
        super().__init__(config)
        self.base_model_path = self.get_conf('base_model', required=True)
        # the model itself is only loaded when run() is called
        self.base_model = None
        self.training_folder = self.get_conf('training_folder', required=True)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')

        # loads the processes from the config
        self.load_processes(process_dict)

    def run(self):
        """Load the base model, then run every process in order."""
        super().run()

        # load models
        print("Loading base model for training")
        print(f" - Loading base model: {self.base_model_path}")
        self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)

        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()

View File

@@ -1,2 +1,3 @@
from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
from .TrainJob import TrainJob

View File

@@ -0,0 +1,25 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process.BaseProcess import BaseProcess
class BaseTrainProcess(BaseProcess):
    """Common base for processes that run inside a TrainJob."""

    job: TrainJob
    process_id: int
    config: OrderedDict

    def __init__(self, process_id: int, job: TrainJob, config: OrderedDict):
        """Store the process index, the owning job and this process's config.

        Args:
            process_id: Identifier passed in by the job for this process.
            job: The TrainJob that owns this process.
            config: Configuration for this process.
        """
        super().__init__(process_id, job, config)
        # kept as explicit attributes on this class as well
        self.process_id, self.job, self.config = process_id, job, config

    def run(self):
        """Run the process; the base implementation does nothing.

        Implement in child classes, and be sure to call ``super().run()``
        first.
        """
        pass

View File

@@ -27,7 +27,7 @@ mode_dict = {
}
class LoconExtractProcess(BaseExtractProcess):
class ExtractLoconProcess(BaseExtractProcess):
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')

View File

@@ -0,0 +1,13 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process import BaseTrainProcess
class TrainFineTuneProcess(BaseTrainProcess):
    """Fine-tune training process (registered for TrainJob).

    The training logic itself is not implemented yet; ``run`` currently
    only defers to the base class.
    """

    def __init__(self, process_id: int, job: TrainJob, config: OrderedDict):
        """Forward the process index, owning job and config to the base class."""
        super().__init__(process_id, job, config)

    def run(self):
        """Run the fine-tune process."""
        # The base class asks child classes to call super().run() first.
        super().run()
        # TODO: implement fine-tune training

View File

@@ -1,3 +1,5 @@
from .BaseExtractProcess import BaseExtractProcess
from .LoconExtractProcess import LoconExtractProcess
from .ExtractLoconProcess import ExtractLoconProcess
from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess
from .TrainFineTuneProcess import TrainFineTuneProcess

View File

@@ -11,5 +11,8 @@ def get_job(config_path) -> BaseJob:
if job == 'extract':
from jobs import ExtractJob
return ExtractJob(config)
elif job == 'train':
from jobs import TrainJob
return TrainJob(config)
else:
raise ValueError(f'Unknown job type {job}')