Added better optimizer choices and param support

Jaret Burkett
2023-07-24 09:21:58 -06:00
parent 9a2819900c
commit e6fb0229bf
5 changed files with 63 additions and 22 deletions


@@ -7,6 +7,7 @@ from typing import List, Literal
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.optimizer import get_optimizer
 from toolkit.paths import REPOS_ROOT
 import sys
@@ -41,6 +42,7 @@ def flush():
 UNET_IN_CHANNELS = 4  # Stable Diffusion's in_channels is fixed at 4. Same for XL.
 VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8

 class StableDiffusion:
     def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
         self.vae = vae
@@ -98,6 +100,7 @@ class TrainConfig:
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
+        self.optimizer_params = kwargs.get('optimizer_params', {})


 class ModelConfig:
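With this change the training config can carry an optimizer_params dict, read here via kwargs and later forwarded verbatim to the optimizer constructor. A minimal sketch of what those kwargs might look like; the optimizer name and keyword arguments below are illustrative placeholders, not values taken from this commit:

```python
# Illustrative training-config kwargs; the keys mirror the kwargs.get calls
# above, and optimizer_params is passed through to the optimizer constructor.
train_kwargs = {
    "optimizer": "adamw",
    "lr": 1e-4,
    "optimizer_params": {
        "weight_decay": 1e-2,
        "betas": (0.9, 0.999),
    },
}
```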
@@ -377,17 +380,14 @@ class TrainSliderProcess(BaseTrainProcess):
         self.network.prepare_grad_etc(text_encoder, unet)

-        optimizer_type = self.train_config.optimizer.lower()
-        # we call it something different than leco
-        if optimizer_type == "dadaptation":
-            optimizer_type = "dadaptadam"
-        optimizer_module = train_util.get_optimizer(optimizer_type)
-        optimizer = optimizer_module(
-            self.network.prepare_optimizer_params(
-                self.train_config.lr, self.train_config.lr, self.train_config.lr
-            ),
-            lr=self.train_config.lr
+        params = self.network.prepare_optimizer_params(
+            text_encoder_lr=self.train_config.lr,
+            unet_lr=self.train_config.lr,
+            default_lr=self.train_config.lr
         )
+        optimizer_type = self.train_config.optimizer.lower()
+        optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
+                                  optimizer_params=self.train_config.optimizer_params)

         lr_scheduler = train_util.get_lr_scheduler(
             self.train_config.lr_scheduler,
             optimizer,
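The rewritten call site hands the network's parameter groups, the optimizer name, the base learning rate, and the new optimizer_params dict to toolkit.optimizer.get_optimizer, imported in the first hunk. That helper is among the changed files but its body is not shown in this excerpt; below is a hedged sketch of what such a dispatcher could look like, with illustrative optimizer choices rather than the commit's actual implementation:

```python
# Hedged sketch of an optimizer dispatcher matching the call site above.
# The real toolkit/optimizer.py is part of this commit but not shown here;
# the supported optimizer names below are assumptions for illustration.
import torch


def get_optimizer(params, optimizer_type: str, learning_rate: float = 1e-6,
                  optimizer_params: dict = None):
    optimizer_params = optimizer_params or {}
    name = optimizer_type.lower()
    if name == "adamw":
        return torch.optim.AdamW(params, lr=learning_rate, **optimizer_params)
    if name == "adam":
        return torch.optim.Adam(params, lr=learning_rate, **optimizer_params)
    if name == "sgd":
        return torch.optim.SGD(params, lr=learning_rate, **optimizer_params)
    if name in ("dadaptation", "dadaptadam"):
        # optional dependency; D-Adaptation typically expects lr around 1.0
        import dadaptation
        return dadaptation.DAdaptAdam(params, lr=learning_rate, **optimizer_params)
    raise ValueError(f"Unknown optimizer type: {optimizer_type}")
```

Routing every choice through one helper keeps the training process code free of per-optimizer special cases (such as the old dadaptation-to-dadaptadam rename) and lets optimizer_params cover constructor arguments that differ between optimizers.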