Setup base for training jobs. Added sd-scripts as a submodule

This commit is contained in:
Jaret Burkett
2023-07-08 13:50:59 -06:00
parent 37354b006e
commit 47d094e528
13 changed files with 151 additions and 18 deletions

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "repositories/sd-scripts"]
path = repositories/sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git

View File

@@ -0,0 +1,32 @@
{
"job": "train",
"config": {
"name": "name_of_your_model",
"base_model": "/path/to/base/model",
"training_folder": "/path/to/output/folder",
"is_v2": false,
"device": "cpu",
"process": [
{
"type": "fine_tune"
}
]
},
"meta": {
"name": "[name]",
"description": "A short description of your model",
"trigger_words": [
"put",
"trigger",
"words",
"here"
],
"version": "0.1",
"creator": {
"name": "Your Name",
"email": "your@email.com",
"website": "https://yourwebsite.com"
},
"any": "All meta data above is arbitrary, it can be whatever you want."
}
}

View File

@@ -1,4 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from typing import List
from jobs.process import BaseProcess
class BaseJob: class BaseJob:
@@ -6,6 +9,7 @@ class BaseJob:
job: str job: str
name: str name: str
meta: OrderedDict meta: OrderedDict
process: List[BaseProcess]
def __init__(self, config: OrderedDict): def __init__(self, config: OrderedDict):
if not config: if not config:
@@ -37,6 +41,25 @@ class BaseJob:
# be sure to call super().run() first # be sure to call super().run() first
pass pass
def load_processes(self, process_dict: dict):
# only call if you have processes in this job type
if 'process' not in self.config:
raise ValueError('config file is invalid. Missing "config.process" key')
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')
# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
if 'type' not in process:
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
# check if dict key is process type
if process['type'] in process_dict:
self.process.append(process_dict[process['type']](i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
def cleanup(self): def cleanup(self):
# if you implement this in child clas, # if you implement this in child clas,
# be sure to call super().cleanup() LAST # be sure to call super().cleanup() LAST

View File

@@ -5,6 +5,12 @@ from typing import List
from jobs.process import BaseExtractProcess from jobs.process import BaseExtractProcess
from jobs.process import ExtractLoconProcess
process_dict = {
'locon': ExtractLoconProcess,
}
class ExtractJob(BaseJob): class ExtractJob(BaseJob):
process: List[BaseExtractProcess] process: List[BaseExtractProcess]
@@ -19,21 +25,8 @@ class ExtractJob(BaseJob):
self.is_v2 = self.get_conf('is_v2', False) self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu') self.device = self.get_conf('device', 'cpu')
if 'process' not in self.config: # loads the processes from the config
raise ValueError('config file is invalid. Missing "config.process" key') self.load_processes(process_dict)
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')
# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
if 'type' not in process:
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
if process['type'] == 'locon':
from jobs.process import LoconExtractProcess
self.process.append(LoconExtractProcess(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
def run(self): def run(self):
super().run() super().run()
@@ -50,4 +43,3 @@ class ExtractJob(BaseJob):
for process in self.process: for process in self.process:
process.run() process.run()

38
jobs/TrainJob.py Normal file
View File

@@ -0,0 +1,38 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
process_dict = {
'fine_tine': TrainFineTuneProcess
}
class TrainJob(BaseJob):
process: List[BaseExtractProcess]
def __init__(self, config: OrderedDict):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.training_folder = self.get_conf('training_folder', required=True)
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')
# loads the processes from the config
self.load_processes(process_dict)
def run(self):
super().run()
# load models
print(f"Loading base model for training")
print(f" - Loading base model: {self.base_model_path}")
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
for process in self.process:
process.run()

View File

@@ -1,2 +1,3 @@
from .BaseJob import BaseJob from .BaseJob import BaseJob
from .ExtractJob import ExtractJob from .ExtractJob import ExtractJob
from .TrainJob import TrainJob

View File

@@ -0,0 +1,25 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process.BaseProcess import BaseProcess
class BaseTrainProcess(BaseProcess):
job: TrainJob
process_id: int
config: OrderedDict
def __init__(
self,
process_id: int,
job: TrainJob,
config: OrderedDict
):
super().__init__(process_id, job, config)
self.process_id = process_id
self.job = job
self.config = config
def run(self):
# implement in child class
# be sure to call super().run() first
pass

View File

@@ -27,7 +27,7 @@ mode_dict = {
} }
class LoconExtractProcess(BaseExtractProcess): class ExtractLoconProcess(BaseExtractProcess):
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict): def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
super().__init__(process_id, job, config) super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed') self.mode = self.get_conf('mode', 'fixed')

View File

@@ -0,0 +1,13 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process import BaseTrainProcess
class TrainFineTuneProcess(BaseTrainProcess):
def __init__(self,process_id: int, job: TrainJob, config: OrderedDict):
super().__init__(process_id, job, config)
def run(self):
# implement in child class
# be sure to call super().run() first
pass

View File

@@ -1,3 +1,5 @@
from .BaseExtractProcess import BaseExtractProcess from .BaseExtractProcess import BaseExtractProcess
from .LoconExtractProcess import LoconExtractProcess from .ExtractLoconProcess import ExtractLoconProcess
from .BaseProcess import BaseProcess from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess
from .TrainFineTuneProcess import TrainFineTuneProcess

View File

@@ -11,5 +11,8 @@ def get_job(config_path) -> BaseJob:
if job == 'extract': if job == 'extract':
from jobs import ExtractJob from jobs import ExtractJob
return ExtractJob(config) return ExtractJob(config)
elif job == 'train':
from jobs import TrainJob
return TrainJob(config)
else: else:
raise ValueError(f'Unknown job type {job}') raise ValueError(f'Unknown job type {job}')