Moved some of the job config into the base process so it will be easier to build extensions

This commit is contained in:
Jaret Burkett
2023-08-10 12:14:05 -06:00
parent fbc8a87a05
commit df48f0a843
6 changed files with 43 additions and 30 deletions

View File

@@ -21,6 +21,7 @@ process_dict = {
'lora_hack': 'TrainLoRAHack',
'rescale_sd': 'TrainSDRescaleProcess',
'esrgan': 'TrainESRGANProcess',
'reference': 'TrainReferenceProcess',
}
@@ -36,18 +37,9 @@ class TrainJob(BaseJob):
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
self.log_dir = self.get_conf('log_dir', None)
self.writer = None
self.setup_tensorboard()
# loads the processes from the config
self.load_processes(process_dict)
def save_training_config(self):
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs(self.training_folder, exist_ok=True)
save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
with open(save_dif, 'w') as f:
yaml.dump(self.raw_config, f)
def run(self):
super().run()
@@ -56,12 +48,3 @@ class TrainJob(BaseJob):
for process in self.process:
process.run()
def setup_tensorboard(self):
if self.log_dir:
from torch.utils.tensorboard import SummaryWriter
now = datetime.now()
time_str = now.strftime('%Y%m%d-%H%M%S')
summary_name = f"{self.name}_{time_str}"
summary_dir = os.path.join(self.log_dir, summary_name)
self.writer = SummaryWriter(summary_dir)

View File

@@ -15,6 +15,8 @@ class BaseProcess(object):
self.process_id = process_id
self.job = job
self.config = config
self.raw_process_config = config
self.name = self.get_conf('name', self.job.name)
self.meta = copy.deepcopy(self.job.meta)
print(json.dumps(self.config, indent=4))

View File

@@ -40,7 +40,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network_config = NetworkConfig(**network_config)
else:
self.network_config = None
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.train_config = TrainConfig(**self.get_conf('train', {}))
self.model_config = ModelConfig(**self.get_conf('model', {}))
self.save_config = SaveConfig(**self.get_conf('save', {}))
@@ -320,8 +319,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet.train()
params += unet.parameters()
# TODO recover save if training network. Maybe load from beginning
### HOOK ###
params = self.hook_add_extra_train_params(params)

View File

@@ -1,14 +1,24 @@
from datetime import datetime
import os
from collections import OrderedDict
from typing import ForwardRef
from typing import TYPE_CHECKING, Union
import yaml
from jobs.process.BaseProcess import BaseProcess
if TYPE_CHECKING:
from jobs import TrainJob, BaseJob, ExtensionJob
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class BaseTrainProcess(BaseProcess):
process_id: int
config: OrderedDict
progress_bar: ForwardRef('tqdm') = None
writer: 'SummaryWriter'
job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
progress_bar: 'tqdm' = None
def __init__(
self,
@@ -18,11 +28,14 @@ class BaseTrainProcess(BaseProcess):
):
super().__init__(process_id, job, config)
self.progress_bar = None
self.writer = self.job.writer
self.writer = None
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.save_root = os.path.join(self.training_folder, self.job.name)
self.step = 0
self.first_step = 0
self.log_dir = self.get_conf('log_dir', self.job.log_dir)
self.setup_tensorboard()
self.save_training_config()
def run(self):
super().run()
@@ -37,3 +50,19 @@ class BaseTrainProcess(BaseProcess):
self.progress_bar.update()
else:
print(*args)
def setup_tensorboard(self):
if self.log_dir:
from torch.utils.tensorboard import SummaryWriter
now = datetime.now()
time_str = now.strftime('%Y%m%d-%H%M%S')
summary_name = f"{self.name}_{time_str}"
summary_dir = os.path.join(self.log_dir, summary_name)
self.writer = SummaryWriter(summary_dir)
def save_training_config(self):
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs(self.training_folder, exist_ok=True)
save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
with open(save_dif, 'w') as f:
yaml.dump(self.raw_process_config, f)

View File

@@ -52,18 +52,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
pass
def hook_before_train_loop(self):
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
# read line by line from file
if self.slider_config.prompt_file:
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
self.print(f"Found {len(self.prompt_txt_list)} prompts.")
if not self.slider_config.prompt_tensors:
print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps

View File

@@ -56,6 +56,7 @@ class EncodedPromptPair:
# simulate torch to for tensors
def to(self, *args, **kwargs):
self.target_class = self.target_class.to(*args, **kwargs)
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
self.positive_target = self.positive_target.to(*args, **kwargs)
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
self.negative_target = self.negative_target.to(*args, **kwargs)
@@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache(
prompt_pair_batch = []
if both or erase_negative:
print("Encoding erase negative")
# print("Encoding erase negative")
prompt_pair_batch += [
# erase standard
EncodedPromptPair(
@@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache(
),
]
if both or enhance_positive:
print("Encoding enhance positive")
# print("Encoding enhance positive")
prompt_pair_batch += [
# enhance standard, swap pos neg
EncodedPromptPair(
@@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache(
),
]
if both or enhance_positive:
print("Encoding erase positive (inverse)")
# print("Encoding erase positive (inverse)")
prompt_pair_batch += [
# erase inverted
EncodedPromptPair(
@@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache(
),
]
if both or erase_negative:
print("Encoding enhance negative (inverse)")
# print("Encoding enhance negative (inverse)")
prompt_pair_batch += [
# enhance inverted
EncodedPromptPair(