mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-04 01:59:48 +00:00
Moved some of the job config into base process so it will be easier to extend extensions
This commit is contained in:
@@ -21,6 +21,7 @@ process_dict = {
|
||||
'lora_hack': 'TrainLoRAHack',
|
||||
'rescale_sd': 'TrainSDRescaleProcess',
|
||||
'esrgan': 'TrainESRGANProcess',
|
||||
'reference': 'TrainReferenceProcess',
|
||||
}
|
||||
|
||||
|
||||
@@ -36,18 +37,9 @@ class TrainJob(BaseJob):
|
||||
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
||||
self.log_dir = self.get_conf('log_dir', None)
|
||||
|
||||
self.writer = None
|
||||
self.setup_tensorboard()
|
||||
|
||||
# loads the processes from the config
|
||||
self.load_processes(process_dict)
|
||||
|
||||
def save_training_config(self):
|
||||
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||
os.makedirs(self.training_folder, exist_ok=True)
|
||||
save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
|
||||
with open(save_dif, 'w') as f:
|
||||
yaml.dump(self.raw_config, f)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -56,12 +48,3 @@ class TrainJob(BaseJob):
|
||||
|
||||
for process in self.process:
|
||||
process.run()
|
||||
|
||||
def setup_tensorboard(self):
|
||||
if self.log_dir:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
now = datetime.now()
|
||||
time_str = now.strftime('%Y%m%d-%H%M%S')
|
||||
summary_name = f"{self.name}_{time_str}"
|
||||
summary_dir = os.path.join(self.log_dir, summary_name)
|
||||
self.writer = SummaryWriter(summary_dir)
|
||||
|
||||
@@ -15,6 +15,8 @@ class BaseProcess(object):
|
||||
self.process_id = process_id
|
||||
self.job = job
|
||||
self.config = config
|
||||
self.raw_process_config = config
|
||||
self.name = self.get_conf('name', self.job.name)
|
||||
self.meta = copy.deepcopy(self.job.meta)
|
||||
print(json.dumps(self.config, indent=4))
|
||||
|
||||
|
||||
@@ -40,7 +40,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.network_config = NetworkConfig(**network_config)
|
||||
else:
|
||||
self.network_config = None
|
||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
||||
self.train_config = TrainConfig(**self.get_conf('train', {}))
|
||||
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
||||
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
||||
@@ -320,8 +319,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unet.train()
|
||||
params += unet.parameters()
|
||||
|
||||
# TODO recover save if training network. Maybe load from beginning
|
||||
|
||||
### HOOK ###
|
||||
params = self.hook_add_extra_train_params(params)
|
||||
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import ForwardRef
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jobs import TrainJob, BaseJob, ExtensionJob
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class BaseTrainProcess(BaseProcess):
|
||||
process_id: int
|
||||
config: OrderedDict
|
||||
progress_bar: ForwardRef('tqdm') = None
|
||||
writer: 'SummaryWriter'
|
||||
job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
|
||||
progress_bar: 'tqdm' = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -18,11 +28,14 @@ class BaseTrainProcess(BaseProcess):
|
||||
):
|
||||
super().__init__(process_id, job, config)
|
||||
self.progress_bar = None
|
||||
self.writer = self.job.writer
|
||||
self.writer = None
|
||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
||||
self.save_root = os.path.join(self.training_folder, self.job.name)
|
||||
self.step = 0
|
||||
self.first_step = 0
|
||||
self.log_dir = self.get_conf('log_dir', self.job.log_dir)
|
||||
self.setup_tensorboard()
|
||||
self.save_training_config()
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -37,3 +50,19 @@ class BaseTrainProcess(BaseProcess):
|
||||
self.progress_bar.update()
|
||||
else:
|
||||
print(*args)
|
||||
|
||||
def setup_tensorboard(self):
|
||||
if self.log_dir:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
now = datetime.now()
|
||||
time_str = now.strftime('%Y%m%d-%H%M%S')
|
||||
summary_name = f"{self.name}_{time_str}"
|
||||
summary_dir = os.path.join(self.log_dir, summary_name)
|
||||
self.writer = SummaryWriter(summary_dir)
|
||||
|
||||
def save_training_config(self):
|
||||
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||
os.makedirs(self.training_folder, exist_ok=True)
|
||||
save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
|
||||
with open(save_dif, 'w') as f:
|
||||
yaml.dump(self.raw_process_config, f)
|
||||
|
||||
@@ -52,18 +52,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
pass
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
|
||||
|
||||
# read line by line from file
|
||||
if self.slider_config.prompt_file:
|
||||
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
|
||||
with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
|
||||
self.prompt_txt_list = f.readlines()
|
||||
# clean empty lines
|
||||
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||
|
||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||
self.print(f"Found {len(self.prompt_txt_list)} prompts.")
|
||||
|
||||
if not self.slider_config.prompt_tensors:
|
||||
print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
|
||||
# shuffle
|
||||
random.shuffle(self.prompt_txt_list)
|
||||
# trim to max steps
|
||||
|
||||
@@ -56,6 +56,7 @@ class EncodedPromptPair:
|
||||
# simulate torch to for tensors
|
||||
def to(self, *args, **kwargs):
|
||||
self.target_class = self.target_class.to(*args, **kwargs)
|
||||
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
|
||||
self.positive_target = self.positive_target.to(*args, **kwargs)
|
||||
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
|
||||
self.negative_target = self.negative_target.to(*args, **kwargs)
|
||||
@@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
prompt_pair_batch = []
|
||||
|
||||
if both or erase_negative:
|
||||
print("Encoding erase negative")
|
||||
# print("Encoding erase negative")
|
||||
prompt_pair_batch += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
@@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
print("Encoding enhance positive")
|
||||
# print("Encoding enhance positive")
|
||||
prompt_pair_batch += [
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
@@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
print("Encoding erase positive (inverse)")
|
||||
# print("Encoding erase positive (inverse)")
|
||||
prompt_pair_batch += [
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
@@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
),
|
||||
]
|
||||
if both or erase_negative:
|
||||
print("Encoding enhance negative (inverse)")
|
||||
# print("Encoding enhance negative (inverse)")
|
||||
prompt_pair_batch += [
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
|
||||
Reference in New Issue
Block a user