diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py
index c9a7ee34..5f971462 100644
--- a/jobs/TrainJob.py
+++ b/jobs/TrainJob.py
@@ -26,7 +26,7 @@ class TrainJob(BaseJob):
         self.device = self.get_conf('device', 'cpu')
         self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
         self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
-        self.logging_dir = self.get_conf('logging_dir', None)
+        self.log_dir = self.get_conf('log_dir', None)
         self.writer = None
         self.setup_tensorboard()

@@ -43,9 +43,9 @@ class TrainJob(BaseJob):
             process.run()

     def setup_tensorboard(self):
-        if self.logging_dir:
+        if self.log_dir:
             from torch.utils.tensorboard import SummaryWriter
             self.writer = SummaryWriter(
-                log_dir=self.logging_dir,
+                log_dir=self.log_dir,
                 filename_suffix=f"_{self.name}"
             )
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index f000b077..5804ded7 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -45,21 +45,22 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae_path = self.get_conf('vae_path', required=True)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.training_folder = self.get_conf('training_folder', self.job.training_folder)
-        self.batch_size = self.get_conf('batch_size', 1)
-        self.resolution = self.get_conf('resolution', 256)
-        self.learning_rate = self.get_conf('learning_rate', 1e-6)
+        self.batch_size = self.get_conf('batch_size', 1, as_type=int)
+        self.resolution = self.get_conf('resolution', 256, as_type=int)
+        self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
         self.sample_every = self.get_conf('sample_every', None)
-        self.epochs = self.get_conf('epochs', None)
-        self.max_steps = self.get_conf('max_steps', None)
+        self.optimizer_type = self.get_conf('optimizer', 'adam')
+        self.epochs = self.get_conf('epochs', None, as_type=int)
+        self.max_steps = self.get_conf('max_steps', None, as_type=int)
         self.save_every = self.get_conf('save_every', None)
         self.dtype = self.get_conf('dtype', 'float32')
         self.sample_sources = self.get_conf('sample_sources', None)
-        self.log_every = self.get_conf('log_every', 100)
-        self.style_weight = self.get_conf('style_weight', 0)
-        self.content_weight = self.get_conf('content_weight', 0)
-        self.kld_weight = self.get_conf('kld_weight', 0)
-        self.mse_weight = self.get_conf('mse_weight', 1e0)
-        self.tv_weight = self.get_conf('tv_weight', 1e0)
+        self.log_every = self.get_conf('log_every', 100, as_type=int)
+        self.style_weight = self.get_conf('style_weight', 0, as_type=float)
+        self.content_weight = self.get_conf('content_weight', 0, as_type=float)
+        self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
+        self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])

         self.writer = self.job.writer
@@ -309,7 +310,12 @@ class TrainVAEProcess(BaseTrainProcess):
             self.vgg_19.eval()

         # todo allow other optimizers
-        optimizer = torch.optim.Adam(params, lr=self.learning_rate)
+        if self.optimizer_type == 'dadaptation':
+            import dadaptation
+            print("Using DAdaptAdam optimizer")
+            optimizer = dadaptation.DAdaptAdam(params, lr=1)
+        else:
+            optimizer = torch.optim.Adam(params, lr=float(self.learning_rate))

         # setup scheduler
         # todo allow other schedulers
@@ -393,7 +399,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 if self.tv_weight > 0:
                     loss_string += f" tv: {tv_loss.item():.2e}"

-                learning_rate = optimizer.param_groups[0]['lr']
+                if self.optimizer_type.startswith('dadaptation'):
+                    learning_rate = (
+                        optimizer.param_groups[0]["d"] *
+                        optimizer.param_groups[0]["lr"]
+                    )
+                else:
+                    learning_rate = optimizer.param_groups[0]['lr']
                 self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
                 self.progress_bar.set_description(f"E: {epoch}")
                 self.progress_bar.update(1)
diff --git a/requirements.txt b/requirements.txt
index f5355be2..a8b4231f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,6 @@ transformers
 lycoris_lora
 flatten_json
 accelerator
+pyyaml
+oyaml
+tensorboard
\ No newline at end of file
diff --git a/run.py b/run.py
index 04730fff..15c57a2a 100644
--- a/run.py
+++ b/run.py
@@ -27,7 +27,7 @@ def main():
         'config_file_list',
         nargs='+',
         type=str,
-        help='Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
+        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
     )

     # flag to continue if failed job
diff --git a/toolkit/config.py b/toolkit/config.py
index b3116bac..d0d3f15a 100644
--- a/toolkit/config.py
+++ b/toolkit/config.py
@@ -1,10 +1,11 @@
 import os
 import json
+import oyaml as yaml
 from collections import OrderedDict

 from toolkit.paths import TOOLKIT_ROOT

-possible_extensions = ['.json', '.jsonc']
+possible_extensions = ['.json', '.jsonc', '.yaml', '.yml']


 def get_cwd_abs_path(path):
@@ -49,8 +50,14 @@ def get_config(config_file_path):
     if not real_config_path:
         raise ValueError(f"Could not find config file {config_file_path}")

-    # load the config
-    with open(real_config_path, 'r') as f:
-        config = json.load(f, object_pairs_hook=OrderedDict)
+    # if we found it, check if it is a json or yaml file
+    if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
+        with open(real_config_path, 'r') as f:
+            config = json.load(f, object_pairs_hook=OrderedDict)
+    elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
+        with open(real_config_path, 'r') as f:
+            config = yaml.load(f, Loader=yaml.FullLoader)
+    else:
+        raise ValueError(f"Config file {config_file_path} must be a json or yaml file")

     return preprocess_config(config)
diff --git a/toolkit/style.py b/toolkit/style.py
index 52be08ba..01fbec73 100644
--- a/toolkit/style.py
+++ b/toolkit/style.py
@@ -121,8 +121,9 @@ class Normalization(nn.Module):
         # normalize to min and max of 0 - 1
         in_min = torch.min(stacked_input)
         in_max = torch.max(stacked_input)
-        norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
-        return (norm_stacked_input - self.mean) / self.std
+        # norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
+        # return (norm_stacked_input - self.mean) / self.std
+        return (stacked_input - self.mean) / self.std


 def get_style_model_and_losses(
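
Example config (illustration only, not part of this patch): with the toolkit/config.py change above, config files can now be written in YAML (loaded via oyaml to preserve key order) as well as JSON, and TrainVAEProcess now reads an 'optimizer' key. The sketch below shows what such a file could look like; the top-level job/config/process layout and the dataset fields are assumptions not shown in this diff, while the individual keys (vae_path, batch_size, learning_rate, optimizer, the loss weights, log_dir, etc.) correspond to values read by TrainJob and TrainVAEProcess in the patch.

# config/train_vae_v1.yaml -- hypothetical example
job: train                        # top-level layout is assumed, not shown in this diff
config:
  name: train_vae_v1
  log_dir: ./logs                 # enables TensorBoard via TrainJob.setup_tensorboard()
  device: cuda
  process:
    - type: train_vae             # process type name is assumed
      vae_path: ./models/vae.safetensors
      training_folder: ./output
      datasets:
        - folder_path: ./data/images   # dataset schema is assumed
      batch_size: 4
      resolution: 256
      learning_rate: 1.0e-6
      optimizer: dadaptation      # new key; defaults to 'adam' when omitted
      epochs: 10
      save_every: 1000
      log_every: 100
      mse_weight: 1.0
      kld_weight: 0.0
      tv_weight: 1.0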