WIP implementing training

Jaret Burkett
2023-07-12 08:23:46 -06:00
parent 47d094e528
commit 57f14e5ef2
16 changed files with 1031 additions and 67 deletions

View File: README.md

@@ -10,17 +10,25 @@ a general understanding of python, pip, pytorch, and using virtual environments:
Linux:
```bash
+git submodule update --init --recursive
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
+cd requirements/sd-scripts
+pip install --no-deps -e .
+cd ../..
```

Windows:
```bash
+git submodule update --init --recursive
python3 -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
+cd requirements/sd-scripts
+pip install --no-deps -e .
+cd ../..
```

## Current Tools

View File

@@ -5,7 +5,11 @@
"base_model": "/path/to/base/model", "base_model": "/path/to/base/model",
"training_folder": "/path/to/output/folder", "training_folder": "/path/to/output/folder",
"is_v2": false, "is_v2": false,
"device": "cpu", "device": "cuda",
"gradient_accumulation_steps": 1,
"mixed_precision": "fp16",
"logging_dir": "/path/to/tensorboard/log/folder",
"process": [ "process": [
{ {
"type": "fine_tune" "type": "fine_tune"

View File: jobs/BaseJob.py

@@ -1,3 +1,4 @@
+import importlib
from collections import OrderedDict
from typing import List
@@ -48,6 +49,8 @@ class BaseJob:
        if len(self.config['process']) == 0:
            raise ValueError('config file is invalid. "config.process" must be a list of processes')

+        module = importlib.import_module('jobs.process')
+
        # add the processes
        self.process = []
        for i, process in enumerate(self.config['process']):
@@ -56,7 +59,8 @@ class BaseJob:
            # check if dict key is process type
            if process['type'] in process_dict:
-                self.process.append(process_dict[process['type']](i, self, process))
+                ProcessClass = getattr(module, process_dict[process['type']])
+                self.process.append(ProcessClass(i, self, process))
            else:
                raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
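The change above swaps direct class references for class-name strings that are resolved lazily with importlib, so `jobs.process` (and its heavier dependencies) is only imported when a process is actually built. A minimal sketch of the pattern using the names from the diff; the `build_process` helper is hypothetical:

```python
import importlib

# values are class names, not classes, so nothing is imported at module load time
process_dict = {
    'locon': 'ExtractLoconProcess',
}

def build_process(process_type: str, *args):
    # resolve the class by name only when it is needed
    module = importlib.import_module('jobs.process')
    ProcessClass = getattr(module, process_dict[process_type])
    return ProcessClass(*args)
```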

View File: jobs/ExtractJob.py

@@ -1,19 +1,16 @@
-from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
-from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess
+from jobs import BaseJob
-from jobs.process import ExtractLoconProcess

process_dict = {
-    'locon': ExtractLoconProcess,
+    'locon': 'ExtractLoconProcess',
}


class ExtractJob(BaseJob):
-    process: List[BaseExtractProcess]

    def __init__(self, config: OrderedDict):
        super().__init__(config)

View File: jobs/TrainJob.py

@@ -1,38 +1,85 @@
-from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
-from .BaseJob import BaseJob
-from collections import OrderedDict
-from typing import List
-
-from jobs.process import BaseExtractProcess, TrainFineTuneProcess
-
-process_dict = {
-    'fine_tine': TrainFineTuneProcess
-}
-
-
-class TrainJob(BaseJob):
-    process: List[BaseExtractProcess]
-
-    def __init__(self, config: OrderedDict):
-        super().__init__(config)
-        self.base_model_path = self.get_conf('base_model', required=True)
-        self.base_model = None
-        self.training_folder = self.get_conf('training_folder', required=True)
-        self.is_v2 = self.get_conf('is_v2', False)
-        self.device = self.get_conf('device', 'cpu')
-
-        # loads the processes from the config
-        self.load_processes(process_dict)
-
-    def run(self):
-        super().run()
-        # load models
-        print(f"Loading base model for training")
-        print(f" - Loading base model: {self.base_model_path}")
-        self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
-
-        print("")
-        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
-
-        for process in self.process:
-            process.run()
+# from jobs import BaseJob
+# from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
+# from collections import OrderedDict
+# from typing import List
+# from jobs.process import BaseExtractProcess, TrainFineTuneProcess
+# import gc
+# import time
+# import argparse
+# import itertools
+# import math
+# import os
+# from multiprocessing import Value
+#
+# from tqdm import tqdm
+# import torch
+# from accelerate.utils import set_seed
+# from accelerate import Accelerator
+# import diffusers
+# from diffusers import DDPMScheduler
+#
+# from toolkit.paths import SD_SCRIPTS_ROOT
+#
+# import sys
+#
+# sys.path.append(SD_SCRIPTS_ROOT)
+#
+# import library.train_util as train_util
+# import library.config_util as config_util
+# from library.config_util import (
+#     ConfigSanitizer,
+#     BlueprintGenerator,
+# )
+# import toolkit.train_tools as train_tools
+# import library.custom_train_functions as custom_train_functions
+# from library.custom_train_functions import (
+#     apply_snr_weight,
+#     get_weighted_text_embeddings,
+#     prepare_scheduler_for_custom_training,
+#     pyramid_noise_like,
+#     apply_noise_offset,
+#     scale_v_prediction_loss_like_noise_prediction,
+# )
+#
+# process_dict = {
+#     'fine_tine': 'TrainFineTuneProcess'
+# }
+#
+#
+# class TrainJob(BaseJob):
+#     process: List[BaseExtractProcess]
+#
+#     def __init__(self, config: OrderedDict):
+#         super().__init__(config)
+#         self.base_model_path = self.get_conf('base_model', required=True)
+#         self.base_model = None
+#         self.training_folder = self.get_conf('training_folder', required=True)
+#         self.is_v2 = self.get_conf('is_v2', False)
+#         self.device = self.get_conf('device', 'cpu')
+#         self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
+#         self.mixed_precision = self.get_conf('mixed_precision', False)  # fp16
+#         self.logging_dir = self.get_conf('logging_dir', None)
+#
+#         # loads the processes from the config
+#         self.load_processes(process_dict)
+#
+#         # setup accelerator
+#         self.accelerator = Accelerator(
+#             gradient_accumulation_steps=self.gradient_accumulation_steps,
+#             mixed_precision=self.mixed_precision,
+#             log_with=None if self.logging_dir is None else 'tensorboard',
+#             logging_dir=self.logging_dir,
+#         )
+#
+#     def run(self):
+#         super().run()
+#         # load models
+#         print(f"Loading base model for training")
+#         print(f" - Loading base model: {self.base_model_path}")
+#         self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
+#
+#         print("")
+#         print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+#
+#         for process in self.process:
+#             process.run()

View File: jobs/__init__.py

@@ -1,3 +1,2 @@
from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
-from .TrainJob import TrainJob

View File: jobs/process/BaseExtractProcess.py

@@ -3,13 +3,13 @@ from collections import OrderedDict
from safetensors.torch import save_file
-from jobs import ExtractJob
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors
+from typing import ForwardRef


class BaseExtractProcess(BaseProcess):
-    job: ExtractJob
    process_id: int
    config: OrderedDict
    output_folder: str
@@ -19,7 +19,7 @@ class BaseExtractProcess(BaseProcess):
    def __init__(
            self,
            process_id: int,
-            job: ExtractJob,
+            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)

View File: jobs/process/BaseProcess.py

@@ -1,8 +1,7 @@
import copy
import json
from collections import OrderedDict
-from typing import ForwardRef
-from jobs import BaseJob


class BaseProcess:
@@ -11,7 +10,7 @@ class BaseProcess:
    def __init__(
            self,
            process_id: int,
-            job: BaseJob,
+            job: 'BaseJob',
            config: OrderedDict
    ):
        self.process_id = process_id
@@ -40,3 +39,5 @@ class BaseProcess:
    def add_meta(self, additional_meta: OrderedDict):
        self.meta.update(additional_meta)
+
+from jobs import BaseJob
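Several of the edits above apply the same fix for the circular dependency between `jobs` and `jobs.process`: the typed import disappears from the top of the module, the annotation becomes the string form `'BaseJob'`, and (where the name is still needed at runtime) the import moves to the bottom of the file. A condensed sketch of the idea, using only names that appear in the diff:

```python
from collections import OrderedDict

class BaseProcess:
    def __init__(self, process_id: int, job: 'BaseJob', config: OrderedDict):
        # 'BaseJob' is a forward reference (a plain string), so jobs does not have
        # to be importable while this class is being defined
        self.process_id = process_id
        self.job = job
        self.config = config

# deferred to the bottom so that jobs -> jobs.process -> jobs does not loop at import time
from jobs import BaseJob
```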

View File: jobs/process/BaseTrainProcess.py

@@ -1,17 +1,15 @@
from collections import OrderedDict
-from jobs import TrainJob
from jobs.process.BaseProcess import BaseProcess


class BaseTrainProcess(BaseProcess):
-    job: TrainJob
    process_id: int
    config: OrderedDict

    def __init__(
            self,
            process_id: int,
-            job: TrainJob,
+            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)

View File: jobs/process/ExtractLoconProcess.py

@@ -1,7 +1,6 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
-from .. import ExtractJob

mode_dict = {
    'fixed': {
@@ -28,7 +27,7 @@ mode_dict = {
class ExtractLoconProcess(BaseExtractProcess):
-    def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
+    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.mode = self.get_conf('mode', 'fixed')
        self.use_sparse_bias = self.get_conf('use_sparse_bias', False)

View File: requirements.txt

@@ -4,3 +4,4 @@ diffusers
transformers
lycoris_lora
flatten_json
+accelerator

View File: run.py

@@ -1,9 +1,5 @@
import os
import sys
-from collections import OrderedDict
-
-from jobs import BaseJob

sys.path.insert(0, os.getcwd())
import argparse
from toolkit.job import get_job
@@ -49,6 +45,8 @@ def main():
    jobs_completed = 0
    jobs_failed = 0

+    print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")
+
    for config_file in config_file_list:
        try:
            job = get_job(config_file)

View File: scripts/train_dreambooth.py (new file)

@@ -0,0 +1,547 @@
import gc
import time
import argparse
import itertools
import math
import os
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import toolkit.train_tools as train_tools
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
# perlin_noise,
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
SD_SCRIPTS_ROOT = os.path.join(PROJECT_ROOT, "repositories", "sd-scripts")
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed)  # initialize the random number sequence
tokenizer = train_util.load_tokenizer(args)
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# replace captions with names
if args.name_replace is not None:
print(f"Replacing captions [name] with '{args.name_replace}'")
train_dataset_group = train_tools.replace_filewords_in_dataset_group(
train_dataset_group, args
)
# prepare the accelerator
print("prepare accelerator")
if args.gradient_accumulation_steps > 1:
print(
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
)
print(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# prepare dtypes compatible with mixed precision and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# load the models
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
# verify load/save model formats
if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None
else:
src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors
else:
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# hook xformers or memory efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# prepare for training: put the models into the proper state
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True)  # added just in case
text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder:
print("Text Encoder is not trained.")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# prepare the classes needed for training
print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare the dataloader
# 0 DataLoader workers means data is loaded in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# compute the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# also pass the max training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# prepare the lr scheduler. TODO: the handling of gradient_accumulation_steps may be off; check later
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: train fully in fp16, including gradients; cast the whole model to fp16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# the accelerator apparently takes care of this for us
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
# experimental feature: full fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume if requested
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# compute the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# start training
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
if args.sample_first or args.sample_only:
# Do initial sample before starting training
train_tools.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer,
text_encoder, unet, force_sample=True)
if args.sample_only:
return
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
# state at the start of each epoch: the text encoder is trained until the specified step
unet.train()
# train==True is required to enable gradient_checkpointing
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# stop training the text encoder at the specified step
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
with torch.no_grad():
# convert to latents
if cache_latents:
latents = batch["latents"].to(accelerator.device)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
if args.train_noise_seed is not None:
torch.manual_seed(args.train_noise_seed)
torch.cuda.manual_seed(args.train_noise_seed)
# make same seed for each item in the batch by stacking them
single_noise = torch.randn_like(latents[0])
noise = torch.stack([single_noise for _ in range(b_size)])
noise = noise.to(latents.device)
elif args.seed_lock:
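# seed_lock derives the seed from a hash of each latent (see toolkit/train_tools.get_seeds_from_latents),
# so the same latent always receives the same noise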
noise = train_tools.get_noise_from_latents(latents)
else:
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
# elif args.perlin_noise:
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
# Get the text embedding for conditioning
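# gradients flow through the text encoder only while global_step is below stop_text_encoder_training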
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
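# the target is the velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents, not the raw noise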
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"]  # per-sample weight
loss = loss * loss_weights
if args.min_snr_gamma:
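# Min-SNR weighting: scales the per-sample loss by roughly min(SNR(t), gamma) / SNR(t),
# de-emphasizing low-noise (high-SNR) timesteps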
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = loss.mean()  # already a mean, no need to divide by batch_size
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder:
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
# save the model every N steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
# checking for saving is in util
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
del accelerator  # free this, since memory is needed afterwards
if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--no_token_padding",
action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作",
)
parser.add_argument(
"--stop_text_encoder_training",
type=int,
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--sample_first",
action="store_true",
help="Sample first interval before training",
default=False
)
parser.add_argument(
"--name_replace",
type=str,
help="Replaces [name] in prompts. Used is sampling, training, and regs",
default=None
)
parser.add_argument(
"--train_noise_seed",
type=int,
help="Use custom seed for training noise",
default=None
)
parser.add_argument(
"--sample_only",
action="store_true",
help="Only generate samples. Used for generating training data with specific seeds to alter during training",
default=False
)
parser.add_argument(
"--seed_lock",
action="store_true",
help="Locks the seed to the latent images so the same latent will always have the same noise",
default=False
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File: toolkit/job.py

@@ -1,8 +1,7 @@
-from jobs import BaseJob
from toolkit.config import get_config


-def get_job(config_path) -> BaseJob:
+def get_job(config_path):
    config = get_config(config_path)
    if not config['job']:
        raise ValueError('config file is invalid. Missing "job" key')
@@ -11,8 +10,8 @@ def get_job(config_path) -> BaseJob:
    if job == 'extract':
        from jobs import ExtractJob
        return ExtractJob(config)
-    elif job == 'train':
-        from jobs import TrainJob
-        return TrainJob(config)
+    # elif job == 'train':
+    #     from jobs import TrainJob
+    #     return TrainJob(config)
    else:
        raise ValueError(f'Unknown job type {job}')

View File: toolkit/paths.py

@@ -2,3 +2,4 @@ import os
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
+SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
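The new SD_SCRIPTS_ROOT constant points at the sd-scripts submodule checked out under repositories/. The commented-out TrainJob earlier in this commit hints at how it is intended to be used; a minimal sketch of that pattern:

```python
import sys
from toolkit.paths import SD_SCRIPTS_ROOT

# make the vendored sd-scripts checkout importable before pulling in its modules
sys.path.append(SD_SCRIPTS_ROOT)

import library.train_util as train_util  # noqa: E402
```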

View File: toolkit/train_tools.py (new file)

@@ -0,0 +1,361 @@
import argparse
import hashlib
import json
import os
import time
from diffusers import (
StableDiffusionPipeline,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import torch
import re
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
def replace_filewords_prompt(prompt, args: argparse.Namespace):
# if name_replace attr in args (may not be)
if hasattr(args, "name_replace") and args.name_replace is not None:
# replace [name] to args.name_replace
prompt = prompt.replace("[name]", args.name_replace)
if hasattr(args, "prepend") and args.prepend is not None:
# prepend to every item in prompt file
prompt = args.prepend + ' ' + prompt
if hasattr(args, "append") and args.append is not None:
# append to every item in prompt file
prompt = prompt + ' ' + args.append
return prompt
def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace):
# if name_replace attr in args (may not be)
if hasattr(args, "name_replace") and args.name_replace is not None:
if not len(dataset_group.image_data) > 0:
# throw error
raise ValueError("dataset_group.image_data is empty")
for key in dataset_group.image_data:
dataset_group.image_data[key].caption = dataset_group.image_data[key].caption.replace(
"[name]", args.name_replace)
return dataset_group
def get_seeds_from_latents(latents):
# latents shape = (batch_size, 4, height, width)
# for speed we only use 8x8 slice of the first channel
seeds = []
# split batch up
for i in range(latents.shape[0]):
# use only first channel, multiply by 255 and convert to int
tensor = latents[i, 0, :, :] * 255.0 # shape = (height, width)
# slice 8x8
tensor = tensor[:8, :8]
# clip to 0-255
tensor = torch.clamp(tensor, 0, 255)
# convert to 8bit int
tensor = tensor.to(torch.uint8)
# convert to bytes
tensor_bytes = tensor.cpu().numpy().tobytes()
# hash
hash_object = hashlib.sha256(tensor_bytes)
# get hex
hex_dig = hash_object.hexdigest()
# convert to int
seed = int(hex_dig, 16) % (2 ** 32)
# append
seeds.append(seed)
return seeds
def get_noise_from_latents(latents):
seed_list = get_seeds_from_latents(latents)
noise = []
for seed in seed_list:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
noise.append(torch.randn_like(latents[0]))
return torch.stack(noise)
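# NOTE: because the seed comes from a hash of each latent (get_seeds_from_latents above),
# repeated calls on the same latents return identical noise; scripts/train_dreambooth.py
# relies on this behavior for the --seed_lock option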
# mix 0 is completely noise mean, mix 1 is completely target mean
def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
dim = dim or (1, 2, 3)
# reduce mean of noise on dim 2, 3, keeping 0 and 1 intact
noise_mean = noise.mean(dim=dim, keepdim=True)
target_mean = target.mean(dim=dim, keepdim=True)
new_noise_mean = mix * target_mean + (1 - mix) * noise_mean
noise = noise - noise_mean + new_noise_mean
return noise
def sample_images(
accelerator,
args: argparse.Namespace,
epoch,
steps,
device,
vae,
tokenizer,
text_encoder,
unet,
prompt_replacement=None,
force_sample=False
):
"""
Uses a modified StableDiffusionLongPromptWeightingPipeline, so clip skip and prompt weighting are supported.
"""
if not force_sample:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps is ignored in this case
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
is_sample_only = args.sample_only
is_generating_only = hasattr(args, "is_generating_only") and args.is_generating_only
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
org_vae_device = vae.device  # should currently be on the CPU
vae.to(device)
# read prompts
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines()
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
# prepare the scheduler
sched_init_args = {}
if args.sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif args.sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sample_sampler
elif args.sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif args.sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler
if args.v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
)
# force clip_sample=True
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
pipeline = StableDiffusionLongPromptWeightingPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
clip_skip=args.clip_skip,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
pipeline.to(device)
if is_generating_only:
save_dir = args.output_dir
else:
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
with torch.no_grad():
with accelerator.autocast():
for i, prompt in enumerate(prompts):
if not accelerator.is_main_process:
continue
if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
prompt = prompt.get("prompt")
prompt = replace_filewords_prompt(prompt, args)
negative_prompt = replace_filewords_prompt(negative_prompt, args)
else:
prompt = replace_filewords_prompt(prompt, args)
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue
# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
).images[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
if is_generating_only:
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
)
else:
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{i:04d}{seed_suffix}.png"
)
if is_sample_only:
# make prompt txt file
img_path_no_ext = os.path.join(save_dir, img_filename[:-4])
with open(img_path_no_ext + ".txt", "w") as f:
# put prompt in txt file
f.write(prompt)
# close file
f.close()
image.save(os.path.join(save_dir, img_filename))
# send logs only when wandb is enabled
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError:  # checked beforehand, so this should not raise here
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
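For reference, the prompt files consumed by sample_images support inline generation flags split on " --", as parsed by the regexes above. A small illustration of how one line would be interpreted; the prompt text itself is made up:

```python
# hypothetical line from a sample_prompts .txt file
line = "a photo of [name] in a forest --w 768 --h 512 --d 42 --s 25 --l 7.5 --n blurry, low quality"

# sample_images splits on " --" and reads:
#   prompt            -> "a photo of [name] in a forest"
#   w 768             -> width = 768
#   h 512             -> height = 512
#   d 42              -> seed = 42
#   s 25              -> sample_steps = 25
#   l 7.5             -> scale = 7.5
#   n blurry, ...     -> negative_prompt = "blurry, low quality"
# "[name]" is substituted first by replace_filewords_prompt when --name_replace is set
```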