diff --git a/README.md b/README.md
index 4925f238..4fabac63 100644
--- a/README.md
+++ b/README.md
@@ -104,3 +104,8 @@
 Just went in and out.
 It is much worse on smaller faces than shown here.
 
+## TODO
+- [ ] Add proper regs on sliders
+- [ ] Add SDXL support (base model only for now)
+- [ ] Add plain erasing
+- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
\ No newline at end of file
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index 3453d7d9..339004c9 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -7,6 +7,7 @@ from typing import List, Literal
 
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.optimizer import get_optimizer
 from toolkit.paths import REPOS_ROOT
 import sys
 
@@ -41,6 +42,7 @@ def flush():
 UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion; the same holds for XL.
 VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 
+
 class StableDiffusion:
     def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
         self.vae = vae
@@ -98,6 +100,7 @@ class TrainConfig:
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
+        self.optimizer_params = kwargs.get('optimizer_params', {})
 
 
 class ModelConfig:
@@ -377,17 +380,14 @@ class TrainSliderProcess(BaseTrainProcess):
 
         self.network.prepare_grad_etc(text_encoder, unet)
 
-        optimizer_type = self.train_config.optimizer.lower()
-        # we call it something different than leco
-        if optimizer_type == "dadaptation":
-            optimizer_type = "dadaptadam"
-        optimizer_module = train_util.get_optimizer(optimizer_type)
-        optimizer = optimizer_module(
-            self.network.prepare_optimizer_params(
-                self.train_config.lr, self.train_config.lr, self.train_config.lr
-            ),
-            lr=self.train_config.lr
+        params = self.network.prepare_optimizer_params(
+            text_encoder_lr=self.train_config.lr,
+            unet_lr=self.train_config.lr,
+            default_lr=self.train_config.lr
         )
+        optimizer_type = self.train_config.optimizer.lower()
+        optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
+                                  optimizer_params=self.train_config.optimizer_params)
         lr_scheduler = train_util.get_lr_scheduler(
             self.train_config.lr_scheduler,
             optimizer,
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 1ec467f1..c62f4ef9 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -50,7 +50,8 @@ class Critic:
             lambda_gp=10,
             start_step=0,
             warmup_steps=1000,
-            process=None
+            process=None,
+            optimizer_params=None,
     ):
         self.learning_rate = learning_rate
         self.device = device
@@ -65,6 +66,10 @@ class Critic:
         self.warmup_steps = warmup_steps
         self.start_step = start_step
         self.lambda_gp = lambda_gp
+
+        if optimizer_params is None:
+            optimizer_params = {}
+        self.optimizer_params = optimizer_params
         self.print = self.process.print
         print(f" Critic config: {self.__dict__}")
 
@@ -75,7 +80,8 @@ class Critic:
         self.model.train()
         self.model.requires_grad_(True)
         params = self.model.parameters()
-        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+                                       optimizer_params=self.optimizer_params)
         self.scheduler = torch.optim.lr_scheduler.ConstantLR(
             self.optimizer,
             total_iters=self.process.max_steps * self.num_critic_per_gen,
@@ -196,6 +202,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+        self.optimizer_params = self.get_conf('optimizer_params', {})
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
 
         self.torch_dtype = get_torch_dtype(self.dtype)
@@ -342,7 +349,8 @@ class TrainVAEProcess(BaseTrainProcess):
 
     def get_pattern_loss(self, pred, target):
         if self._pattern_loss is None:
-            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype)
+            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device,
+                                                                                        dtype=self.torch_dtype)
         loss = torch.mean(self._pattern_loss(pred, target))
         return loss
 
@@ -504,7 +512,8 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()
 
-        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+                                  optimizer_params=self.optimizer_params)
 
         # setup scheduler
         # todo allow other schedulers
diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py
index 808d63b8..6fd2f9c2 100644
--- a/jobs/process/models/vgg19_critic.py
+++ b/jobs/process/models/vgg19_critic.py
@@ -21,11 +21,11 @@ class Vgg19Critic(nn.Module):
         super(Vgg19Critic, self).__init__()
         self.main = nn.Sequential(
             # input (bs, 512, 32, 32)
-            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
             nn.LeakyReLU(0.2),  # (bs, 512, 16, 16)
-            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
             nn.LeakyReLU(0.2),  # (bs, 512, 8, 8)
-            nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
             # (bs, 1, 4, 4)
             MeanReduce(),  # (bs, 1, 1, 1)
             nn.Flatten(),  # (bs, 1)
diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index bba10cf1..a9a9b6e7 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -4,18 +4,45 @@ import torch
 def get_optimizer(
         params,
         optimizer_type='adam',
-        learning_rate=1e-6
+        learning_rate=1e-6,
+        optimizer_params=None
 ):
+    if optimizer_params is None:
+        optimizer_params = {}
     lower_type = optimizer_type.lower()
-    if lower_type == 'dadaptation':
+    if lower_type.startswith("dadaptation"):
         # dadaptation optimizer does not use standard learning rate. 1 is the default value
         import dadaptation
         print("Using DAdaptAdam optimizer")
-        optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
+        use_lr = learning_rate
+        if use_lr < 0.1:
+            # dadaptation expects an lr in the 0.1 to 1.0 range; default to 1.0
+            use_lr = 1.0
+        if lower_type.endswith('lion'):
+            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+        elif lower_type.endswith('adam'):
+            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+        elif lower_type == 'dadaptation':
+            # backwards compatibility
+            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+            # warn user that dadaptation is deprecated
+            print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
+    elif lower_type.endswith("8bit"):
+        import bitsandbytes
+
+        if lower_type == "adam8bit":
+            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+        elif lower_type == "lion8bit":
+            return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
+        else:
+            raise ValueError(f'Unknown optimizer type {optimizer_type}')
     elif lower_type == 'adam':
-        optimizer = torch.optim.Adam(params, lr=float(learning_rate))
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'adamw':
-        optimizer = torch.optim.AdamW(params, lr=float(learning_rate))
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
+    elif lower_type == 'lion':
+        from lion_pytorch import Lion
+        return Lion(params, lr=learning_rate, **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
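
For reference, a minimal usage sketch of the updated `get_optimizer` signature. The model, learning rate, and the `weight_decay`/`betas` values below are illustrative placeholders, not values defined anywhere in this diff; only the `optimizer_params` pass-through itself comes from the patch.

```python
# Hypothetical example of the new optimizer_params pass-through.
# Everything except get_optimizer's signature is made up for illustration.
import torch.nn as nn

from toolkit.optimizer import get_optimizer

model = nn.Linear(16, 16)  # stand-in for the network being trained

optimizer = get_optimizer(
    model.parameters(),
    optimizer_type='adamw',   # matched case-insensitively
    learning_rate=1e-4,
    # forwarded as **kwargs to torch.optim.AdamW
    optimizer_params={'weight_decay': 1e-2, 'betas': (0.9, 0.99)},
)
```

Note that the `lion`, `*8bit`, and `dadaptation*` branches import lion_pytorch, bitsandbytes, and dadaptation lazily, so those packages only need to be installed when the corresponding optimizer type is selected.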