mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-01 03:31:35 +00:00)
Added better optimizer choices and param support
@@ -7,6 +7,7 @@ from typing import List, Literal
 from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.optimizer import get_optimizer
 from toolkit.paths import REPOS_ROOT
 import sys

@@ -41,6 +42,7 @@ def flush():
 UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion; the same holds for XL.
 VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8


 class StableDiffusion:
     def __init__(self, vae, tokenizer, text_encoder, unet, noise_scheduler):
         self.vae = vae
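The VAE_SCALE_FACTOR comment compresses a small derivation: the SD (and SDXL) VAE has four block_out_channels, so 2 ** (4 - 1) = 8. A quick check with diffusers (illustrative only; assumes the stock Stable Diffusion 1.5 VAE and is not part of this commit):

# Verifies the VAE_SCALE_FACTOR comment above; assumes diffusers is installed
# and the standard SD 1.5 VAE config, whose block_out_channels has length 4.
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
print(len(vae.config.block_out_channels))             # -> 4
print(2 ** (len(vae.config.block_out_channels) - 1))  # -> 8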
@@ -98,6 +100,7 @@ class TrainConfig:
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
+        self.optimizer_params = kwargs.get('optimizer_params', {})


 class ModelConfig:
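With optimizer_params stored on TrainConfig, a training config can now name an optimizer and pass keyword arguments straight through to its constructor. A minimal usage sketch (only lr, optimizer, and optimizer_params appear in this diff; the concrete values are illustrative assumptions):

# Hypothetical TrainConfig usage; 'optimizer' and 'optimizer_params' are the
# keys this commit wires through, the values shown are assumptions.
train_config = TrainConfig(
    lr=1e-4,
    optimizer='adamw',                  # lowered, then dispatched by name
    optimizer_params={                  # forwarded to the optimizer constructor
        'weight_decay': 1e-2,
        'betas': (0.9, 0.99),
    },
)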
@@ -377,17 +380,14 @@ class TrainSliderProcess(BaseTrainProcess):

         self.network.prepare_grad_etc(text_encoder, unet)

-        optimizer_type = self.train_config.optimizer.lower()
-        # we call it something different than leco
-        if optimizer_type == "dadaptation":
-            optimizer_type = "dadaptadam"
-        optimizer_module = train_util.get_optimizer(optimizer_type)
-        optimizer = optimizer_module(
-            self.network.prepare_optimizer_params(
-                self.train_config.lr, self.train_config.lr, self.train_config.lr
-            ),
-            lr=self.train_config.lr
-        )
+        params = self.network.prepare_optimizer_params(
+            text_encoder_lr=self.train_config.lr,
+            unet_lr=self.train_config.lr,
+            default_lr=self.train_config.lr
+        )
+        optimizer_type = self.train_config.optimizer.lower()
+        optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
+                                  optimizer_params=self.train_config.optimizer_params)
         lr_scheduler = train_util.get_lr_scheduler(
             self.train_config.lr_scheduler,
             optimizer,
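toolkit/optimizer.py itself is not part of this diff, so the body of get_optimizer is unknown here. A minimal sketch of such a factory, assuming it dispatches on the lowered optimizer name and forwards optimizer_params as constructor kwargs (names and supported optimizers are assumptions, not the repo's actual implementation):

# Sketch of a get_optimizer factory matching the call site above; illustrative
# only. Dispatches on optimizer_type and forwards optimizer_params as kwargs.
import torch

def get_optimizer(params, optimizer_type: str, learning_rate: float = 1e-4,
                  optimizer_params: dict = None):
    optimizer_params = optimizer_params or {}
    optimizer_type = optimizer_type.lower()
    if optimizer_type == 'adam':
        return torch.optim.Adam(params, lr=learning_rate, **optimizer_params)
    elif optimizer_type == 'adamw':
        return torch.optim.AdamW(params, lr=learning_rate, **optimizer_params)
    elif optimizer_type == 'sgd':
        return torch.optim.SGD(params, lr=learning_rate, **optimizer_params)
    elif optimizer_type == 'dadaptation':
        import dadaptation  # optional dependency, assumed available
        return dadaptation.DAdaptAdam(params, lr=learning_rate, **optimizer_params)
    else:
        raise ValueError(f'Unknown optimizer type: {optimizer_type}')

Note that the removed code special-cased "dadaptation" to kohya's "dadaptadam" name before dispatching through train_util; a factory along these lines can accept the name directly, which is presumably the "better optimizer choices" the commit message refers to.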