mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Fixed issues with VGG19 preprocessing. Added YAML support for config files.

@@ -26,7 +26,7 @@ class TrainJob(BaseJob):
         self.device = self.get_conf('device', 'cpu')
         self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
         self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
-        self.logging_dir = self.get_conf('logging_dir', None)
+        self.log_dir = self.get_conf('log_dir', None)

         self.writer = None
         self.setup_tensorboard()

@@ -43,9 +43,9 @@ class TrainJob(BaseJob):
             process.run()

     def setup_tensorboard(self):
-        if self.logging_dir:
+        if self.log_dir:
             from torch.utils.tensorboard import SummaryWriter
             self.writer = SummaryWriter(
-                log_dir=self.logging_dir,
+                log_dir=self.log_dir,
                 filename_suffix=f"_{self.name}"
             )

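With a writer attached to the job, training processes can push metrics to TensorBoard. A minimal, self-contained sketch of the SummaryWriter calls involved (the log directory, suffix, and tag name below are illustrative, not taken from the repository):

# minimal sketch: writing scalars to TensorBoard with torch's SummaryWriter
# log_dir, filename_suffix, and the tag name are assumptions for illustration
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='output/.tensorboard', filename_suffix='_my_job')
for step in range(100):
    fake_loss = 1.0 / (step + 1)
    writer.add_scalar('loss/mse', fake_loss, global_step=step)
writer.flush()
writer.close()
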
@@ -45,21 +45,22 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae_path = self.get_conf('vae_path', required=True)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.training_folder = self.get_conf('training_folder', self.job.training_folder)
-        self.batch_size = self.get_conf('batch_size', 1)
-        self.resolution = self.get_conf('resolution', 256)
-        self.learning_rate = self.get_conf('learning_rate', 1e-6)
+        self.batch_size = self.get_conf('batch_size', 1, as_type=int)
+        self.resolution = self.get_conf('resolution', 256, as_type=int)
+        self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
         self.sample_every = self.get_conf('sample_every', None)
-        self.epochs = self.get_conf('epochs', None)
-        self.max_steps = self.get_conf('max_steps', None)
+        self.optimizer_type = self.get_conf('optimizer', 'adam')
+        self.epochs = self.get_conf('epochs', None, as_type=int)
+        self.max_steps = self.get_conf('max_steps', None, as_type=int)
         self.save_every = self.get_conf('save_every', None)
         self.dtype = self.get_conf('dtype', 'float32')
         self.sample_sources = self.get_conf('sample_sources', None)
-        self.log_every = self.get_conf('log_every', 100)
-        self.style_weight = self.get_conf('style_weight', 0)
-        self.content_weight = self.get_conf('content_weight', 0)
-        self.kld_weight = self.get_conf('kld_weight', 0)
-        self.mse_weight = self.get_conf('mse_weight', 1e0)
-        self.tv_weight = self.get_conf('tv_weight', 1e0)
+        self.log_every = self.get_conf('log_every', 100, as_type=int)
+        self.style_weight = self.get_conf('style_weight', 0, as_type=float)
+        self.content_weight = self.get_conf('content_weight', 0, as_type=float)
+        self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
+        self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)

         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.writer = self.job.writer

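The as_type coercions above matter once configs can come from YAML: PyYAML only treats scientific notation as a float when it contains a decimal point, so a value like 1e-6 comes back as a string. A minimal sketch of the behavior, assuming oyaml exposes the standard PyYAML interface:

# minimal sketch: why numeric config values are coerced with as_type
import oyaml as yaml

cfg = yaml.load("learning_rate: 1e-6\nbatch_size: 1\n", Loader=yaml.FullLoader)
print(type(cfg['learning_rate']))   # <class 'str'>  (no decimal point, so not a YAML float)
print(float(cfg['learning_rate']))  # 1e-06 after coercion
print(type(cfg['batch_size']))      # <class 'int'>
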
@@ -309,7 +310,12 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vgg_19.eval()

         # todo allow other optimizers
-        optimizer = torch.optim.Adam(params, lr=self.learning_rate)
+        if self.optimizer_type == 'dadaptation':
+            import dadaptation
+            print("Using DAdaptAdam optimizer")
+            optimizer = dadaptation.DAdaptAdam(params, lr=1)
+        else:
+            optimizer = torch.optim.Adam(params, lr=float(self.learning_rate))

         # setup scheduler
         # todo allow other schedulers

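DAdaptAdam estimates its own step size, which is why it is constructed with lr=1 instead of the configured learning rate. A minimal sketch of using it on a toy model, assuming the optional dadaptation package is installed:

# minimal sketch: DAdaptAdam adapts the step size itself, so lr is left at 1
# requires the optional 'dadaptation' package (pip install dadaptation)
import torch
import dadaptation

model = torch.nn.Linear(8, 4)
optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1)

loss = model(torch.randn(2, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
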
@@ -393,7 +399,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 if self.tv_weight > 0:
                     loss_string += f" tv: {tv_loss.item():.2e}"

-                learning_rate = optimizer.param_groups[0]['lr']
+                if self.optimizer_type.startswith('dadaptation'):
+                    learning_rate = (
+                        optimizer.param_groups[0]["d"] *
+                        optimizer.param_groups[0]["lr"]
+                    )
+                else:
+                    learning_rate = optimizer.param_groups[0]['lr']
                 self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
                 self.progress_bar.set_description(f"E: {epoch}")
                 self.progress_bar.update(1)

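Because DAdaptAdam runs with lr=1, the number worth displaying is the adapted step size, which the block above reads as d * lr from the first param group. A small helper expressing the same logic outside the class (an illustrative sketch, not the repository's code):

# minimal sketch: effective learning rate for logging, mirroring the branch above
def effective_lr(optimizer, optimizer_type: str) -> float:
    group = optimizer.param_groups[0]
    if optimizer_type.startswith('dadaptation'):
        # D-Adaptation keeps 'lr' at 1 and stores its adapted step size in 'd'
        return group['d'] * group['lr']
    return group['lr']
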
@@ -6,3 +6,6 @@ transformers
 lycoris_lora
 flatten_json
 accelerator
+pyyaml
+oyaml
+tensorboard

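Of the new dependencies, oyaml is a drop-in wrapper around PyYAML that loads mappings as OrderedDict and dumps them in insertion order, matching what the JSON path gets from object_pairs_hook=OrderedDict. A minimal sketch:

# minimal sketch: oyaml preserves mapping order on load and dump
import oyaml as yaml

doc = yaml.safe_load("b: 2\na: 1\n")
print(list(doc.keys()))  # ['b', 'a'] -- an OrderedDict in file order
print(yaml.dump(doc))    # b: 2 / a: 1, dumped in the same order
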
run.py
@@ -27,7 +27,7 @@ def main():
         'config_file_list',
         nargs='+',
         type=str,
-        help='Name of config file (eg: person_v1 for config/person_v1.json), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
+        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
     )

     # flag to continue if failed job

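With the extended help text, a YAML config can be run by name from the config folder or by full path, and several configs can be queued in a single call; for example (config names are illustrative):

python run.py person_v1
python run.py config/person_v1.yaml config/other_job.yml
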
@@ -1,10 +1,11 @@
 import os
 import json
+import oyaml as yaml
 from collections import OrderedDict

 from toolkit.paths import TOOLKIT_ROOT

-possible_extensions = ['.json', '.jsonc']
+possible_extensions = ['.json', '.jsonc', '.yaml', '.yml']


 def get_cwd_abs_path(path):

@@ -49,8 +50,14 @@ def get_config(config_file_path):
|
||||
if not real_config_path:
|
||||
raise ValueError(f"Could not find config file {config_file_path}")
|
||||
|
||||
# load the config
|
||||
with open(real_config_path, 'r') as f:
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
# if we found it, check if it is a json or yaml file
|
||||
if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
|
||||
with open(real_config_path, 'r') as f:
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
|
||||
with open(real_config_path, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
|
||||
return preprocess_config(config)
|
||||
|
||||
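Both branches should produce the same mapping for equivalent content; a minimal sketch comparing the two load paths on inline strings (the keys shown are illustrative, not a full ai-toolkit config):

# minimal sketch: equivalent JSON and YAML content load to the same ordered mapping
import json
from collections import OrderedDict
import oyaml as yaml

json_text = '{"job": "train", "config": {"name": "test_vae", "log_every": 100}}'
yaml_text = "job: train\nconfig:\n  name: test_vae\n  log_every: 100\n"

from_json = json.loads(json_text, object_pairs_hook=OrderedDict)
from_yaml = yaml.load(yaml_text, Loader=yaml.FullLoader)
print(from_json == from_yaml)  # True
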
@@ -121,8 +121,9 @@ class Normalization(nn.Module):
         # normalize to min and max of 0 - 1
         in_min = torch.min(stacked_input)
         in_max = torch.max(stacked_input)
-        norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
-        return (norm_stacked_input - self.mean) / self.std
+        # norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
+        # return (norm_stacked_input - self.mean) / self.std
+        return (stacked_input - self.mean) / self.std


 def get_style_model_and_losses(

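For reference, VGG19 feature extraction conventionally expects inputs already scaled to [0, 1] and normalized with the ImageNet channel statistics; rescaling each batch by its own min and max (the commented-out lines above) changes those statistics from batch to batch. A minimal standalone sketch of the intended preprocessing, using the usual ImageNet mean/std values (assumed here, and not the repository's exact module):

# minimal sketch: ImageNet-style normalization for a [0, 1] NCHW batch fed to VGG19
import torch
import torch.nn as nn

class VGGNormalization(nn.Module):
    def __init__(self):
        super().__init__()
        # standard ImageNet statistics, reshaped to broadcast over (N, C, H, W)
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        # x is assumed to already be in [0, 1]; no per-batch min/max rescaling
        return (x - self.mean) / self.std

normed = VGGNormalization()(torch.rand(2, 3, 64, 64))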