diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index fc594edf..31ed51ef 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -21,6 +21,7 @@ process_dict = { 'lora_hack': 'TrainLoRAHack', 'rescale_sd': 'TrainSDRescaleProcess', 'esrgan': 'TrainESRGANProcess', + 'reference': 'TrainReferenceProcess', } @@ -36,18 +37,9 @@ class TrainJob(BaseJob): # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 self.log_dir = self.get_conf('log_dir', None) - self.writer = None - self.setup_tensorboard() - # loads the processes from the config self.load_processes(process_dict) - def save_training_config(self): - timestamp = datetime.now().strftime('%Y%m%d-%H%M%S') - os.makedirs(self.training_folder, exist_ok=True) - save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml') - with open(save_dif, 'w') as f: - yaml.dump(self.raw_config, f) def run(self): super().run() @@ -56,12 +48,3 @@ class TrainJob(BaseJob): for process in self.process: process.run() - - def setup_tensorboard(self): - if self.log_dir: - from torch.utils.tensorboard import SummaryWriter - now = datetime.now() - time_str = now.strftime('%Y%m%d-%H%M%S') - summary_name = f"{self.name}_{time_str}" - summary_dir = os.path.join(self.log_dir, summary_name) - self.writer = SummaryWriter(summary_dir) diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py index 167cdc14..c21e7ae7 100644 --- a/jobs/process/BaseProcess.py +++ b/jobs/process/BaseProcess.py @@ -15,6 +15,8 @@ class BaseProcess(object): self.process_id = process_id self.job = job self.config = config + self.raw_process_config = config + self.name = self.get_conf('name', self.job.name) self.meta = copy.deepcopy(self.job.meta) print(json.dumps(self.config, indent=4)) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 8fd652a8..e50ab66e 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -40,7 +40,6 @@ class BaseSDTrainProcess(BaseTrainProcess): self.network_config = NetworkConfig(**network_config) else: self.network_config = None - self.training_folder = self.get_conf('training_folder', self.job.training_folder) self.train_config = TrainConfig(**self.get_conf('train', {})) self.model_config = ModelConfig(**self.get_conf('model', {})) self.save_config = SaveConfig(**self.get_conf('save', {})) @@ -320,8 +319,6 @@ class BaseSDTrainProcess(BaseTrainProcess): unet.train() params += unet.parameters() - # TODO recover save if training network. Maybe load from beginning - ### HOOK ### params = self.hook_add_extra_train_params(params) diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py index c80b9e79..cd6f8619 100644 --- a/jobs/process/BaseTrainProcess.py +++ b/jobs/process/BaseTrainProcess.py @@ -1,14 +1,24 @@ +from datetime import datetime import os from collections import OrderedDict -from typing import ForwardRef +from typing import TYPE_CHECKING, Union + +import yaml from jobs.process.BaseProcess import BaseProcess +if TYPE_CHECKING: + from jobs import TrainJob, BaseJob, ExtensionJob + from torch.utils.tensorboard import SummaryWriter + from tqdm import tqdm + class BaseTrainProcess(BaseProcess): process_id: int config: OrderedDict - progress_bar: ForwardRef('tqdm') = None + writer: 'SummaryWriter' + job: Union['TrainJob', 'BaseJob', 'ExtensionJob'] + progress_bar: 'tqdm' = None def __init__( self, @@ -18,11 +28,14 @@ class BaseTrainProcess(BaseProcess): ): super().__init__(process_id, job, config) self.progress_bar = None - self.writer = self.job.writer + self.writer = None self.training_folder = self.get_conf('training_folder', self.job.training_folder) self.save_root = os.path.join(self.training_folder, self.job.name) self.step = 0 self.first_step = 0 + self.log_dir = self.get_conf('log_dir', self.job.log_dir) + self.setup_tensorboard() + self.save_training_config() def run(self): super().run() @@ -37,3 +50,19 @@ class BaseTrainProcess(BaseProcess): self.progress_bar.update() else: print(*args) + + def setup_tensorboard(self): + if self.log_dir: + from torch.utils.tensorboard import SummaryWriter + now = datetime.now() + time_str = now.strftime('%Y%m%d-%H%M%S') + summary_name = f"{self.name}_{time_str}" + summary_dir = os.path.join(self.log_dir, summary_name) + self.writer = SummaryWriter(summary_dir) + + def save_training_config(self): + timestamp = datetime.now().strftime('%Y%m%d-%H%M%S') + os.makedirs(self.training_folder, exist_ok=True) + save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml') + with open(save_dif, 'w') as f: + yaml.dump(self.raw_process_config, f) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index e10ec5b2..a79e67dc 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -52,18 +52,19 @@ class TrainSliderProcess(BaseSDTrainProcess): pass def hook_before_train_loop(self): - self.print(f"Loading prompt file from {self.slider_config.prompt_file}") # read line by line from file if self.slider_config.prompt_file: + self.print(f"Loading prompt file from {self.slider_config.prompt_file}") with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f: self.prompt_txt_list = f.readlines() # clean empty lines self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] - self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") + self.print(f"Found {len(self.prompt_txt_list)} prompts.") if not self.slider_config.prompt_tensors: + print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.") # shuffle random.shuffle(self.prompt_txt_list) # trim to max steps diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 215d150e..ea999274 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -56,6 +56,7 @@ class EncodedPromptPair: # simulate torch to for tensors def to(self, *args, **kwargs): self.target_class = self.target_class.to(*args, **kwargs) + self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs) self.positive_target = self.positive_target.to(*args, **kwargs) self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs) self.negative_target = self.negative_target.to(*args, **kwargs) @@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache( prompt_pair_batch = [] if both or erase_negative: - print("Encoding erase negative") + # print("Encoding erase negative") prompt_pair_batch += [ # erase standard EncodedPromptPair( @@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache( ), ] if both or enhance_positive: - print("Encoding enhance positive") + # print("Encoding enhance positive") prompt_pair_batch += [ # enhance standard, swap pos neg EncodedPromptPair( @@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache( ), ] if both or enhance_positive: - print("Encoding erase positive (inverse)") + # print("Encoding erase positive (inverse)") prompt_pair_batch += [ # erase inverted EncodedPromptPair( @@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache( ), ] if both or erase_negative: - print("Encoding enhance negative (inverse)") + # print("Encoding enhance negative (inverse)") prompt_pair_batch += [ # enhance inverted EncodedPromptPair(