diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..9828d447 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "repositories/sd-scripts"] + path = repositories/sd-scripts + url = https://github.com/kohya-ss/sd-scripts.git diff --git a/config/examples/extract_config.example.json b/config/examples/extract.example.json similarity index 100% rename from config/examples/extract_config.example.json rename to config/examples/extract.example.json diff --git a/config/examples/train.example.json b/config/examples/train.example.json new file mode 100644 index 00000000..8d19812d --- /dev/null +++ b/config/examples/train.example.json @@ -0,0 +1,32 @@ +{ + "job": "train", + "config": { + "name": "name_of_your_model", + "base_model": "/path/to/base/model", + "training_folder": "/path/to/output/folder", + "is_v2": false, + "device": "cpu", + "process": [ + { + "type": "fine_tune" + } + ] + }, + "meta": { + "name": "[name]", + "description": "A short description of your model", + "trigger_words": [ + "put", + "trigger", + "words", + "here" + ], + "version": "0.1", + "creator": { + "name": "Your Name", + "email": "your@email.com", + "website": "https://yourwebsite.com" + }, + "any": "All meta data above is arbitrary, it can be whatever you want." + } +} \ No newline at end of file diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py index c5a52c69..ae4d0bc5 100644 --- a/jobs/BaseJob.py +++ b/jobs/BaseJob.py @@ -1,4 +1,7 @@ from collections import OrderedDict +from typing import List + +from jobs.process import BaseProcess class BaseJob: @@ -6,6 +9,7 @@ class BaseJob: job: str name: str meta: OrderedDict + process: List[BaseProcess] def __init__(self, config: OrderedDict): if not config: @@ -37,6 +41,25 @@ class BaseJob: # be sure to call super().run() first pass + def load_processes(self, process_dict: dict): + # only call if you have processes in this job type + if 'process' not in self.config: + raise ValueError('config file is invalid. Missing "config.process" key') + if len(self.config['process']) == 0: + raise ValueError('config file is invalid. "config.process" must be a list of processes') + + # add the processes + self.process = [] + for i, process in enumerate(self.config['process']): + if 'type' not in process: + raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key') + + # check if dict key is process type + if process['type'] in process_dict: + self.process.append(process_dict[process['type']](i, self, process)) + else: + raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}') + def cleanup(self): # if you implement this in child clas, # be sure to call super().cleanup() LAST diff --git a/jobs/ExtractJob.py b/jobs/ExtractJob.py index 1b6af8fb..99a86fa5 100644 --- a/jobs/ExtractJob.py +++ b/jobs/ExtractJob.py @@ -5,6 +5,12 @@ from typing import List from jobs.process import BaseExtractProcess +from jobs.process import ExtractLoconProcess + +process_dict = { + 'locon': ExtractLoconProcess, +} + class ExtractJob(BaseJob): process: List[BaseExtractProcess] @@ -19,21 +25,8 @@ class ExtractJob(BaseJob): self.is_v2 = self.get_conf('is_v2', False) self.device = self.get_conf('device', 'cpu') - if 'process' not in self.config: - raise ValueError('config file is invalid. Missing "config.process" key') - if len(self.config['process']) == 0: - raise ValueError('config file is invalid. "config.process" must be a list of processes') - - # add the processes - self.process = [] - for i, process in enumerate(self.config['process']): - if 'type' not in process: - raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key') - if process['type'] == 'locon': - from jobs.process import LoconExtractProcess - self.process.append(LoconExtractProcess(i, self, process)) - else: - raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}') + # loads the processes from the config + self.load_processes(process_dict) def run(self): super().run() @@ -50,4 +43,3 @@ class ExtractJob(BaseJob): for process in self.process: process.run() - diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py new file mode 100644 index 00000000..f898e629 --- /dev/null +++ b/jobs/TrainJob.py @@ -0,0 +1,38 @@ +from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint +from .BaseJob import BaseJob +from collections import OrderedDict +from typing import List + +from jobs.process import BaseExtractProcess, TrainFineTuneProcess + +process_dict = { + 'fine_tine': TrainFineTuneProcess +} + + +class TrainJob(BaseJob): + process: List[BaseExtractProcess] + + def __init__(self, config: OrderedDict): + super().__init__(config) + self.base_model_path = self.get_conf('base_model', required=True) + self.base_model = None + self.training_folder = self.get_conf('training_folder', required=True) + self.is_v2 = self.get_conf('is_v2', False) + self.device = self.get_conf('device', 'cpu') + + # loads the processes from the config + self.load_processes(process_dict) + + def run(self): + super().run() + # load models + print(f"Loading base model for training") + print(f" - Loading base model: {self.base_model_path}") + self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) + + print("") + print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") + + for process in self.process: + process.run() diff --git a/jobs/__init__.py b/jobs/__init__.py index 09be1770..9c2472f6 100644 --- a/jobs/__init__.py +++ b/jobs/__init__.py @@ -1,2 +1,3 @@ from .BaseJob import BaseJob from .ExtractJob import ExtractJob +from .TrainJob import TrainJob diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py new file mode 100644 index 00000000..ba1beac2 --- /dev/null +++ b/jobs/process/BaseTrainProcess.py @@ -0,0 +1,25 @@ +from collections import OrderedDict +from jobs import TrainJob +from jobs.process.BaseProcess import BaseProcess + + +class BaseTrainProcess(BaseProcess): + job: TrainJob + process_id: int + config: OrderedDict + + def __init__( + self, + process_id: int, + job: TrainJob, + config: OrderedDict + ): + super().__init__(process_id, job, config) + self.process_id = process_id + self.job = job + self.config = config + + def run(self): + # implement in child class + # be sure to call super().run() first + pass diff --git a/jobs/process/LoconExtractProcess.py b/jobs/process/ExtractLoconProcess.py similarity index 97% rename from jobs/process/LoconExtractProcess.py rename to jobs/process/ExtractLoconProcess.py index c2133bdf..ba2d039f 100644 --- a/jobs/process/LoconExtractProcess.py +++ b/jobs/process/ExtractLoconProcess.py @@ -27,7 +27,7 @@ mode_dict = { } -class LoconExtractProcess(BaseExtractProcess): +class ExtractLoconProcess(BaseExtractProcess): def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict): super().__init__(process_id, job, config) self.mode = self.get_conf('mode', 'fixed') diff --git a/jobs/process/TrainFineTuneProcess.py b/jobs/process/TrainFineTuneProcess.py new file mode 100644 index 00000000..a13a7cf6 --- /dev/null +++ b/jobs/process/TrainFineTuneProcess.py @@ -0,0 +1,13 @@ +from collections import OrderedDict +from jobs import TrainJob +from jobs.process import BaseTrainProcess + + +class TrainFineTuneProcess(BaseTrainProcess): + def __init__(self,process_id: int, job: TrainJob, config: OrderedDict): + super().__init__(process_id, job, config) + + def run(self): + # implement in child class + # be sure to call super().run() first + pass diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index c480e0ac..8879ed3a 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -1,3 +1,5 @@ from .BaseExtractProcess import BaseExtractProcess -from .LoconExtractProcess import LoconExtractProcess +from .ExtractLoconProcess import ExtractLoconProcess from .BaseProcess import BaseProcess +from .BaseTrainProcess import BaseTrainProcess +from .TrainFineTuneProcess import TrainFineTuneProcess diff --git a/repositories/sd-scripts b/repositories/sd-scripts new file mode 160000 index 00000000..0cfcb5a4 --- /dev/null +++ b/repositories/sd-scripts @@ -0,0 +1 @@ +Subproject commit 0cfcb5a49cf813547d728101cc05edf1a9b7d06c diff --git a/toolkit/job.py b/toolkit/job.py index 5ac7e0c5..2d497f57 100644 --- a/toolkit/job.py +++ b/toolkit/job.py @@ -11,5 +11,8 @@ def get_job(config_path) -> BaseJob: if job == 'extract': from jobs import ExtractJob return ExtractJob(config) + elif job == 'train': + from jobs import TrainJob + return TrainJob(config) else: raise ValueError(f'Unknown job type {job}')