Files
ai-toolkit/jobs/TrainJob.py

56 lines
1.7 KiB
Python

import os
from jobs import BaseJob
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
from datetime import datetime
from toolkit.paths import REPOS_ROOT
import sys
# Put REPOS_ROOT on the module search path so code under it can be imported.
# NOTE(review): this runs after the toolkit imports above, so those imports
# cannot rely on it — confirm that ordering is intentional.
sys.path.append(REPOS_ROOT)

# Process type key -> name of the process class that handles it. The mapping
# is handed to BaseJob.load_processes; presumably the keys match the process
# "type" field in the job config (verify against load_processes).
process_dict = dict(
    vae='TrainVAEProcess',
    slider='TrainSliderProcess',
)
class TrainJob(BaseJob):
    """Job that loads and runs the training processes listed in its config.

    Reads job-level settings (training folder, ``is_v2``, device and an
    optional TensorBoard ``log_dir``), creates a TensorBoard ``SummaryWriter``
    when ``log_dir`` is set, and builds its processes with
    ``load_processes(process_dict)``.
    """

    # NOTE(review): annotated as BaseExtractProcess, but process_dict only maps
    # to training process classes — confirm the intended base type.
    process: List[BaseExtractProcess]

    def __init__(self, config: OrderedDict):
        """Read job settings from ``config`` and load the configured processes.

        Args:
            config: Job configuration forwarded to ``BaseJob``; individual
                values are read back through ``self.get_conf``.
        """
        super().__init__(config)
        self.training_folder = self.get_conf('training_folder', required=True)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')
        # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
        # self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16

        # TensorBoard logging is enabled only when a log_dir is configured.
        self.log_dir = self.get_conf('log_dir', None)
        self.writer = None
        self.setup_tensorboard()

        # loads the processes from the config
        self.load_processes(process_dict)

    def run(self):
        """Run every loaded process in order, then close the TensorBoard writer.

        The writer is closed in a ``finally`` block so any events still
        buffered in it are written out even when a process raises.
        """
        try:
            super().run()
            print("")
            num_processes = len(self.process)
            print(f"Running {num_processes} process{'' if num_processes == 1 else 'es'}")
            for process in self.process:
                process.run()
        finally:
            self.close_tensorboard()

    def setup_tensorboard(self):
        """Create ``self.writer`` under ``log_dir`` when logging is enabled.

        Events go to ``<log_dir>/<name>_<YYYYmmdd-HHMMSS>``, where ``name`` is
        ``self.name`` (presumably set by ``BaseJob``). ``torch`` is imported
        here rather than at module level, so it is only imported when
        ``log_dir`` is set.
        """
        if not self.log_dir:
            return
        from torch.utils.tensorboard import SummaryWriter
        time_str = datetime.now().strftime('%Y%m%d-%H%M%S')
        summary_dir = os.path.join(self.log_dir, f"{self.name}_{time_str}")
        self.writer = SummaryWriter(summary_dir)

    def close_tensorboard(self):
        """Flush and close ``self.writer`` if one was created; otherwise do nothing."""
        if self.writer is not None:
            self.writer.close()