mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Added better optimizer choices and param support
@@ -104,3 +104,8 @@ Just went in and out. It is much worse on smaller faces than shown here.
 
 <img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">
 
+## TODO
+- [ ] Add proper regs on sliders
+- [ ] Add SDXL support (base model only for now)
+- [ ] Add plain erasing
+- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
@@ -7,6 +7,7 @@ from typing import List, Literal
 
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.optimizer import get_optimizer
 from toolkit.paths import REPOS_ROOT
 import sys
 
@@ -41,6 +42,7 @@ def flush():
 UNET_IN_CHANNELS = 4  # Stable Diffusion's in_channels is fixed at 4. Same for XL.
 VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 
+
 class StableDiffusion:
     def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
         self.vae = vae
@@ -98,6 +100,7 @@ class TrainConfig:
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
+        self.optimizer_params = kwargs.get('optimizer_params', {})
 
 
 class ModelConfig:
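For context, a minimal sketch of how the new `optimizer_params` field could be supplied alongside the existing training options. This is not taken from the repo's example configs; the surrounding keys and values are illustrative assumptions.

```python
# Illustrative config fragment; TrainConfig(**train_kwargs) would pick these up
# via kwargs.get(...). The optimizer_params dict is forwarded verbatim to the
# optimizer constructor as keyword arguments.
train_kwargs = {
    'train_unet': True,
    'train_text_encoder': True,
    'lr': 1e-4,                    # assumed key; the trainer reads train_config.lr
    'optimizer': 'adamw',
    'optimizer_params': {          # hypothetical values; any kwargs torch.optim.AdamW accepts
        'weight_decay': 1e-2,
        'betas': (0.9, 0.99),
    },
}
```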
@@ -377,17 +380,14 @@ class TrainSliderProcess(BaseTrainProcess):
 
         self.network.prepare_grad_etc(text_encoder, unet)
 
-        optimizer_type = self.train_config.optimizer.lower()
-        # we call it something different than leco
-        if optimizer_type == "dadaptation":
-            optimizer_type = "dadaptadam"
-        optimizer_module = train_util.get_optimizer(optimizer_type)
-        optimizer = optimizer_module(
-            self.network.prepare_optimizer_params(
-                self.train_config.lr, self.train_config.lr, self.train_config.lr
-            ),
-            lr=self.train_config.lr
+        params = self.network.prepare_optimizer_params(
+            text_encoder_lr=self.train_config.lr,
+            unet_lr=self.train_config.lr,
+            default_lr=self.train_config.lr
         )
+        optimizer_type = self.train_config.optimizer.lower()
+        optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
+                                  optimizer_params=self.train_config.optimizer_params)
         lr_scheduler = train_util.get_lr_scheduler(
             self.train_config.lr_scheduler,
             optimizer,
@@ -50,7 +50,8 @@ class Critic:
             lambda_gp=10,
             start_step=0,
             warmup_steps=1000,
-            process=None
+            process=None,
+            optimizer_params=None,
     ):
         self.learning_rate = learning_rate
         self.device = device
@@ -65,6 +66,10 @@ class Critic:
         self.warmup_steps = warmup_steps
         self.start_step = start_step
         self.lambda_gp = lambda_gp
 
+        if optimizer_params is None:
+            optimizer_params = {}
+        self.optimizer_params = optimizer_params
+
         self.print = self.process.print
         print(f" Critic config: {self.__dict__}")
@@ -75,7 +80,8 @@ class Critic:
         self.model.train()
         self.model.requires_grad_(True)
         params = self.model.parameters()
-        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+                                       optimizer_params=self.optimizer_params)
         self.scheduler = torch.optim.lr_scheduler.ConstantLR(
             self.optimizer,
             total_iters=self.process.max_steps * self.num_critic_per_gen,
@@ -196,6 +202,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+        self.optimizer_params = self.get_conf('optimizer_params', {})
 
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
@@ -342,7 +349,8 @@ class TrainVAEProcess(BaseTrainProcess):
 
     def get_pattern_loss(self, pred, target):
         if self._pattern_loss is None:
-            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype)
+            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device,
+                                                                                        dtype=self.torch_dtype)
         loss = torch.mean(self._pattern_loss(pred, target))
         return loss
 
@@ -504,7 +512,8 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()
 
-        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+                                  optimizer_params=self.optimizer_params)
 
         # setup scheduler
         # todo allow other schedulers
@@ -21,11 +21,11 @@ class Vgg19Critic(nn.Module):
         super(Vgg19Critic, self).__init__()
         self.main = nn.Sequential(
             # input (bs, 512, 32, 32)
-            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
             nn.LeakyReLU(0.2),  # (bs, 512, 16, 16)
-            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
             nn.LeakyReLU(0.2),  # (bs, 512, 8, 8)
-            nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
             # (bs, 1, 4, 4)
             MeanReduce(),  # (bs, 1, 1, 1)
             nn.Flatten(),  # (bs, 1)
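For reference, a small shape check of the widened critic head. `MeanReduce` here is a hypothetical stand-in (assumed to average over the spatial dimensions), since its real definition is not part of this diff; the inline `(bs, 512, ...)` / `(bs, 1)` comments above still describe the previous widths, and with the 1024-channel convs the flattened output becomes `(bs, 1024)`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the toolkit's MeanReduce: assumed to average over
# the spatial dimensions while keeping them as size-1 axes.
class MeanReduce(nn.Module):
    def forward(self, x):
        return x.mean(dim=(2, 3), keepdim=True)

head = nn.Sequential(
    nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
    MeanReduce(),
    nn.Flatten(),
)

x = torch.randn(2, 512, 32, 32)  # (bs, 512, 32, 32), as in the input comment
print(head(x).shape)             # torch.Size([2, 1024])
```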
@@ -4,18 +4,45 @@ import torch
 def get_optimizer(
         params,
         optimizer_type='adam',
-        learning_rate=1e-6
+        learning_rate=1e-6,
+        optimizer_params=None
 ):
+    if optimizer_params is None:
+        optimizer_params = {}
     lower_type = optimizer_type.lower()
-    if lower_type == 'dadaptation':
+    if lower_type.startswith("dadaptation"):
         # dadaptation optimizer does not use standard learning rate. 1 is the default value
         import dadaptation
         print("Using DAdaptAdam optimizer")
-        optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
+        use_lr = learning_rate
+        if use_lr < 0.1:
+            # dadaptation uses a different lr, with values of 0.1 to 1.0. default to 1.0
+            use_lr = 1.0
+        if lower_type.endswith('lion'):
+            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+        elif lower_type.endswith('adam'):
+            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+        elif lower_type == 'dadaptation':
+            # backwards compatibility
+            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+            # warn user that dadaptation is deprecated
+            print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
+    elif lower_type.endswith("8bit"):
+        import bitsandbytes
+
+        if lower_type == "adam8bit":
+            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+        elif lower_type == "lion8bit":
+            return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
+        else:
+            raise ValueError(f'Unknown optimizer type {optimizer_type}')
     elif lower_type == 'adam':
-        optimizer = torch.optim.Adam(params, lr=float(learning_rate))
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'adamw':
-        optimizer = torch.optim.AdamW(params, lr=float(learning_rate))
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
+    elif lower_type == 'lion':
+        from lion_pytorch import Lion
+        return Lion(params, lr=learning_rate, **optimizer_params)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
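As a quick usage sketch (not from the repo's own scripts), the reworked helper can be called directly. The AdamW kwargs below are arbitrary examples of what `optimizer_params` forwards; the dadaptation and 8-bit lines are commented out since they need the optional `dadaptation` / `bitsandbytes` packages.

```python
import torch.nn as nn

from toolkit.optimizer import get_optimizer

model = nn.Linear(8, 8)  # stand-in module, just to have parameters to optimize

# Plain AdamW; optimizer_params is unpacked into torch.optim.AdamW(...)
opt = get_optimizer(
    model.parameters(),
    optimizer_type='adamw',
    learning_rate=1e-4,
    optimizer_params={'weight_decay': 1e-2, 'betas': (0.9, 0.99)},
)

# D-Adaptation variants match on prefix/suffix, so names like 'dadaptation_adam'
# or 'dadaptation_lion' hit the branches above (lr below 0.1 is bumped to 1.0):
# opt = get_optimizer(model.parameters(), 'dadaptation_lion', learning_rate=1.0)

# 8-bit variants return directly from their branch:
# opt = get_optimizer(model.parameters(), 'adam8bit', learning_rate=1e-4)
```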