Moved some of the job config into base process so it will be easier to extend extensions

This commit is contained in:
Jaret Burkett
2023-08-10 12:14:05 -06:00
parent fbc8a87a05
commit df48f0a843
6 changed files with 43 additions and 30 deletions

View File

@@ -21,6 +21,7 @@ process_dict = {
'lora_hack': 'TrainLoRAHack', 'lora_hack': 'TrainLoRAHack',
'rescale_sd': 'TrainSDRescaleProcess', 'rescale_sd': 'TrainSDRescaleProcess',
'esrgan': 'TrainESRGANProcess', 'esrgan': 'TrainESRGANProcess',
'reference': 'TrainReferenceProcess',
} }
@@ -36,18 +37,9 @@ class TrainJob(BaseJob):
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
self.log_dir = self.get_conf('log_dir', None) self.log_dir = self.get_conf('log_dir', None)
self.writer = None
self.setup_tensorboard()
# loads the processes from the config # loads the processes from the config
self.load_processes(process_dict) self.load_processes(process_dict)
def save_training_config(self):
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs(self.training_folder, exist_ok=True)
save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
with open(save_dif, 'w') as f:
yaml.dump(self.raw_config, f)
def run(self): def run(self):
super().run() super().run()
@@ -56,12 +48,3 @@ class TrainJob(BaseJob):
for process in self.process: for process in self.process:
process.run() process.run()
def setup_tensorboard(self):
if self.log_dir:
from torch.utils.tensorboard import SummaryWriter
now = datetime.now()
time_str = now.strftime('%Y%m%d-%H%M%S')
summary_name = f"{self.name}_{time_str}"
summary_dir = os.path.join(self.log_dir, summary_name)
self.writer = SummaryWriter(summary_dir)

View File

@@ -15,6 +15,8 @@ class BaseProcess(object):
self.process_id = process_id self.process_id = process_id
self.job = job self.job = job
self.config = config self.config = config
self.raw_process_config = config
self.name = self.get_conf('name', self.job.name)
self.meta = copy.deepcopy(self.job.meta) self.meta = copy.deepcopy(self.job.meta)
print(json.dumps(self.config, indent=4)) print(json.dumps(self.config, indent=4))

View File

@@ -40,7 +40,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network_config = NetworkConfig(**network_config) self.network_config = NetworkConfig(**network_config)
else: else:
self.network_config = None self.network_config = None
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.train_config = TrainConfig(**self.get_conf('train', {})) self.train_config = TrainConfig(**self.get_conf('train', {}))
self.model_config = ModelConfig(**self.get_conf('model', {})) self.model_config = ModelConfig(**self.get_conf('model', {}))
self.save_config = SaveConfig(**self.get_conf('save', {})) self.save_config = SaveConfig(**self.get_conf('save', {}))
@@ -320,8 +319,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet.train() unet.train()
params += unet.parameters() params += unet.parameters()
# TODO recover save if training network. Maybe load from beginning
### HOOK ### ### HOOK ###
params = self.hook_add_extra_train_params(params) params = self.hook_add_extra_train_params(params)

View File

@@ -1,14 +1,24 @@
from datetime import datetime
import os import os
from collections import OrderedDict from collections import OrderedDict
from typing import ForwardRef from typing import TYPE_CHECKING, Union
import yaml
from jobs.process.BaseProcess import BaseProcess from jobs.process.BaseProcess import BaseProcess
if TYPE_CHECKING:
from jobs import TrainJob, BaseJob, ExtensionJob
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class BaseTrainProcess(BaseProcess): class BaseTrainProcess(BaseProcess):
process_id: int process_id: int
config: OrderedDict config: OrderedDict
progress_bar: ForwardRef('tqdm') = None writer: 'SummaryWriter'
job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
progress_bar: 'tqdm' = None
def __init__( def __init__(
self, self,
@@ -18,11 +28,14 @@ class BaseTrainProcess(BaseProcess):
): ):
super().__init__(process_id, job, config) super().__init__(process_id, job, config)
self.progress_bar = None self.progress_bar = None
self.writer = self.job.writer self.writer = None
self.training_folder = self.get_conf('training_folder', self.job.training_folder) self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.save_root = os.path.join(self.training_folder, self.job.name) self.save_root = os.path.join(self.training_folder, self.job.name)
self.step = 0 self.step = 0
self.first_step = 0 self.first_step = 0
self.log_dir = self.get_conf('log_dir', self.job.log_dir)
self.setup_tensorboard()
self.save_training_config()
def run(self): def run(self):
super().run() super().run()
@@ -37,3 +50,19 @@ class BaseTrainProcess(BaseProcess):
self.progress_bar.update() self.progress_bar.update()
else: else:
print(*args) print(*args)
def setup_tensorboard(self):
if self.log_dir:
from torch.utils.tensorboard import SummaryWriter
now = datetime.now()
time_str = now.strftime('%Y%m%d-%H%M%S')
summary_name = f"{self.name}_{time_str}"
summary_dir = os.path.join(self.log_dir, summary_name)
self.writer = SummaryWriter(summary_dir)
def save_training_config(self):
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs(self.training_folder, exist_ok=True)
save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
with open(save_dif, 'w') as f:
yaml.dump(self.raw_process_config, f)

View File

@@ -52,18 +52,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
pass pass
def hook_before_train_loop(self): def hook_before_train_loop(self):
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
# read line by line from file # read line by line from file
if self.slider_config.prompt_file: if self.slider_config.prompt_file:
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
self.prompt_txt_list = f.readlines() self.prompt_txt_list = f.readlines()
# clean empty lines # clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") self.print(f"Found {len(self.prompt_txt_list)} prompts.")
if not self.slider_config.prompt_tensors: if not self.slider_config.prompt_tensors:
print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
# shuffle # shuffle
random.shuffle(self.prompt_txt_list) random.shuffle(self.prompt_txt_list)
# trim to max steps # trim to max steps

View File

@@ -56,6 +56,7 @@ class EncodedPromptPair:
# simulate torch to for tensors # simulate torch to for tensors
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
self.target_class = self.target_class.to(*args, **kwargs) self.target_class = self.target_class.to(*args, **kwargs)
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
self.positive_target = self.positive_target.to(*args, **kwargs) self.positive_target = self.positive_target.to(*args, **kwargs)
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs) self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
self.negative_target = self.negative_target.to(*args, **kwargs) self.negative_target = self.negative_target.to(*args, **kwargs)
@@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache(
prompt_pair_batch = [] prompt_pair_batch = []
if both or erase_negative: if both or erase_negative:
print("Encoding erase negative") # print("Encoding erase negative")
prompt_pair_batch += [ prompt_pair_batch += [
# erase standard # erase standard
EncodedPromptPair( EncodedPromptPair(
@@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache(
), ),
] ]
if both or enhance_positive: if both or enhance_positive:
print("Encoding enhance positive") # print("Encoding enhance positive")
prompt_pair_batch += [ prompt_pair_batch += [
# enhance standard, swap pos neg # enhance standard, swap pos neg
EncodedPromptPair( EncodedPromptPair(
@@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache(
), ),
] ]
if both or enhance_positive: if both or enhance_positive:
print("Encoding erase positive (inverse)") # print("Encoding erase positive (inverse)")
prompt_pair_batch += [ prompt_pair_batch += [
# erase inverted # erase inverted
EncodedPromptPair( EncodedPromptPair(
@@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache(
), ),
] ]
if both or erase_negative: if both or erase_negative:
print("Encoding enhance negative (inverse)") # print("Encoding enhance negative (inverse)")
prompt_pair_batch += [ prompt_pair_batch += [
# enhance inverted # enhance inverted
EncodedPromptPair( EncodedPromptPair(