mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import json
|
|
import os
|
|
|
|
from jobs import BaseJob
|
|
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
|
|
from collections import OrderedDict
|
|
from typing import List
|
|
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
|
|
from datetime import datetime
|
|
import yaml
|
|
from toolkit.paths import REPOS_ROOT
|
|
|
|
import sys
|
|
|
|
sys.path.append(REPOS_ROOT)
|
|
|
|
# Lookup table handed to ``load_processes`` by TrainJob: config process key ->
# process class name (presumably resolved to a class by BaseJob; verify there).
process_dict = dict(
    vae='TrainVAEProcess',
    slider='TrainSliderProcess',
    lora_hack='TrainLoRAHack',
)
|
|
|
|
|
|
class TrainJob(BaseJob):
    """Job that reads its training settings from config and runs its processes.

    On construction it pulls ``training_folder``, ``is_v2``, ``device`` and
    ``log_dir`` from the config, optionally creates a TensorBoard writer, and
    asks the base class to load processes using ``process_dict``.
    """

    process: List[BaseExtractProcess]

    def __init__(self, config: OrderedDict):
        """Read the job settings from ``config`` and load the processes.

        Args:
            config: Job configuration passed through to ``BaseJob.__init__``.
        """
        super().__init__(config)

        # Settings are looked up through the base-class ``get_conf`` helper;
        # the second argument is presumably the fallback value (confirm in
        # BaseJob). ``training_folder`` is requested with ``required=True``.
        self.training_folder = self.get_conf('training_folder', required=True)
        self.is_v2 = self.get_conf('is_v2', False)
        self.device = self.get_conf('device', 'cpu')
        # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
        # self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
        self.log_dir = self.get_conf('log_dir', None)

        # The writer stays None unless a log directory was configured.
        self.writer = None
        self.setup_tensorboard()

        # Processes are loaded last, after the settings and writer exist.
        self.load_processes(process_dict)

    def save_training_config(self):
        """Dump ``self.raw_config`` as YAML into the training folder.

        The file is named ``run_config_<YYYYmmdd-HHMMSS>.yaml``; the training
        folder is created first if it does not exist yet.
        """
        stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        os.makedirs(self.training_folder, exist_ok=True)

        file_name = f'run_config_{stamp}.yaml'
        config_path = os.path.join(self.training_folder, file_name)
        with open(config_path, 'w') as handle:
            yaml.dump(self.raw_config, handle)

    def run(self):
        """Run the base job, report the process count, then run each process in order."""
        super().run()
        print("")

        total = len(self.process)
        suffix = '' if total == 1 else 'es'
        print(f"Running {total} process{suffix}")

        for job_process in self.process:
            job_process.run()

    def setup_tensorboard(self):
        """Create ``self.writer`` when a log directory is configured.

        Does nothing when ``self.log_dir`` is falsy. Otherwise the writer logs
        into ``<log_dir>/<job name>_<YYYYmmdd-HHMMSS>``.
        """
        if not self.log_dir:
            return

        # Imported here so the tensorboard import only runs when logging is on.
        from torch.utils.tensorboard import SummaryWriter

        stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        run_dir = os.path.join(self.log_dir, f"{self.name}_{stamp}")
        self.writer = SummaryWriter(run_dir)
|