Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-26 17:29:27 +00:00).
Commit message: WIP implementing training.
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
from jobs import BaseJob
|
||||
from toolkit.config import get_config
|
||||
|
||||
|
||||
def get_job(config_path) -> BaseJob:
|
||||
def get_job(config_path):
|
||||
config = get_config(config_path)
|
||||
if not config['job']:
|
||||
raise ValueError('config file is invalid. Missing "job" key')
|
||||
@@ -11,8 +10,8 @@ def get_job(config_path) -> BaseJob:
|
||||
if job == 'extract':
|
||||
from jobs import ExtractJob
|
||||
return ExtractJob(config)
|
||||
elif job == 'train':
|
||||
from jobs import TrainJob
|
||||
return TrainJob(config)
|
||||
# elif job == 'train':
|
||||
# from jobs import TrainJob
|
||||
# return TrainJob(config)
|
||||
else:
|
||||
raise ValueError(f'Unknown job type {job}')
|
||||
|
||||
@@ -2,3 +2,4 @@ import os
|
||||
|
||||
# Absolute path to the repository root (two directory levels above this file).
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Directory holding the job/config files consumed by get_config().
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
# Path to the bundled sd-scripts checkout — assumed to live under
# repositories/sd-scripts; TODO confirm against the repo layout.
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||
|
||||
361
toolkit/train_tools.py
Normal file
361
toolkit/train_tools.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import argparse
import hashlib
import json
import os
import re
import time

import torch
from diffusers import (
    StableDiffusionPipeline,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)

from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
|
||||
# DDPM beta-schedule parameters shared by every sampler built in sample_images().
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
# NOTE(review): name is misspelled ("SCHEDLER"); kept as-is because other code
# in this module (and possibly elsewhere) references it by this exact name.
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def replace_filewords_prompt(prompt, args: argparse.Namespace):
    """Expand a prompt-file line using optional CLI arguments.

    Substitutes the "[name]" placeholder with ``args.name_replace``, then
    prepends ``args.prepend`` and appends ``args.append``, each joined with a
    single space. Attributes that are absent or None are skipped.

    :param prompt: raw prompt line from the prompt file
    :param args: parsed CLI namespace (all three attributes are optional)
    :return: the transformed prompt string
    """
    name_replace = getattr(args, "name_replace", None)
    if name_replace is not None:
        # substitute the [name] placeholder
        prompt = prompt.replace("[name]", name_replace)

    prefix = getattr(args, "prepend", None)
    if prefix is not None:
        # prepend to every item in the prompt file
        prompt = prefix + ' ' + prompt

    suffix = getattr(args, "append", None)
    if suffix is not None:
        # append to every item in the prompt file
        prompt = prompt + ' ' + suffix

    return prompt
|
||||
|
||||
|
||||
def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace):
    """Apply the "[name]" caption substitution across a dataset group.

    Mutates each image entry's caption in place, replacing "[name]" with
    ``args.name_replace`` when that attribute is present and not None.

    :param dataset_group: object exposing an ``image_data`` mapping whose
        values carry a ``caption`` attribute
    :param args: parsed CLI namespace
    :return: the same (mutated) dataset_group
    :raises ValueError: if substitution is requested but image_data is empty
    """
    name_replace = getattr(args, "name_replace", None)
    if name_replace is not None:
        image_data = dataset_group.image_data
        if len(image_data) == 0:
            # an empty dataset here indicates a configuration problem
            raise ValueError("dataset_group.image_data is empty")
        for entry in image_data.values():
            entry.caption = entry.caption.replace("[name]", name_replace)

    return dataset_group
|
||||
|
||||
|
||||
def get_seeds_from_latents(latents):
    """Derive one deterministic 32-bit RNG seed per latent in a batch.

    Hashes an 8x8 slice of each latent's first channel with SHA-256, so the
    same latent always yields the same seed. Only a small slice of one channel
    is hashed, for speed.

    :param latents: tensor of shape (batch_size, 4, height, width)
    :return: list of ints, each in [0, 2**32)
    """
    seeds = []

    # split batch up
    for i in range(latents.shape[0]):
        # use only first channel, scale roughly into the 0-255 range
        tensor = latents[i, 0, :, :] * 255.0  # shape = (height, width)
        # slice 8x8 — enough entropy to distinguish latents, cheap to hash
        tensor = tensor[:8, :8]
        # clip to the uint8 range before converting
        tensor = torch.clamp(tensor, 0, 255)
        # convert to 8bit int
        tensor = tensor.to(torch.uint8)
        # convert to bytes
        tensor_bytes = tensor.cpu().numpy().tobytes()
        # fix: `hashlib` was used here without ever being imported (NameError
        # at runtime); it is now imported at module level
        hash_object = hashlib.sha256(tensor_bytes)
        hex_dig = hash_object.hexdigest()
        # fold the 256-bit digest into torch's 32-bit seed range
        seed = int(hex_dig, 16) % (2 ** 32)
        seeds.append(seed)
    return seeds
|
||||
|
||||
|
||||
def get_noise_from_latents(latents):
    """Build a batch of gaussian noise whose RNG is seeded per-latent.

    Each sample's noise is drawn after seeding torch's CPU and CUDA
    generators with a seed derived deterministically from that latent
    (see get_seeds_from_latents), so identical latents reproduce identical
    noise.

    :param latents: tensor of shape (batch_size, C, H, W)
    :return: tensor with the same shape as ``latents``
    """
    per_sample_noise = []
    for seed in get_seeds_from_latents(latents):
        # seed both generators so CPU and CUDA runs line up
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        per_sample_noise.append(torch.randn_like(latents[0]))
    return torch.stack(per_sample_noise)
|
||||
|
||||
|
||||
# mix 0 is completely noise mean, mix 1 is completely target mean
def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
    """Shift the mean of ``noise`` toward the mean of ``target``.

    Both means are computed over ``dim`` (keeping dims so the result
    broadcasts), and the noise is re-centered onto a blend of the two:
    mix=0 leaves the noise mean untouched, mix=1 adopts the target mean
    exactly.

    :param noise: noise tensor, typically (batch, C, H, W)
    :param target: tensor whose mean offset should be matched
    :param mix: blend factor in [0, 1]
    :param dim: dims to average over; any falsy value falls back to (1, 2, 3)
    :return: the re-centered noise tensor
    """
    if not dim:
        dim = (1, 2, 3)
    current_mean = noise.mean(dim=dim, keepdim=True)
    desired_mean = target.mean(dim=dim, keepdim=True)
    # blend the two means according to mix
    blended_mean = mix * desired_mean + (1 - mix) * current_mean
    # remove the old mean, add the blended one
    return noise - current_mean + blended_mean
|
||||
|
||||
|
||||
def sample_images(
        accelerator,
        args: argparse.Namespace,
        epoch,
        steps,
        device,
        vae,
        tokenizer,
        text_encoder,
        unet,
        prompt_replacement=None,
        force_sample=False
):
    """Generate sample images during/after training and save them to disk.

    Uses a modified StableDiffusionLongPromptWeightingPipeline, so clip skip
    and prompt weighting are supported.

    Prompts come from ``args.sample_prompts``: a .txt file (one prompt per
    line, with ``--w/--h/--d/--s/--l/--n`` inline overrides) or a .json list
    of dicts. Images are written under ``args.output_dir`` (or its ``sample``
    subfolder) and, when a wandb tracker exists, also logged there. Torch RNG
    state is snapshotted before and restored after generation.

    :param accelerator: accelerate Accelerator — used for autocast, the
        main-process check, and the wandb tracker lookup
    :param args: CLI namespace; reads sample_every_n_steps,
        sample_every_n_epochs, sample_only, sample_prompts, sample_sampler,
        v_parameterization, clip_skip, output_dir, output_name, and optionally
        is_generating_only (plus the replace_filewords_* attributes)
    :param epoch: current epoch number, or None when called mid-epoch
    :param steps: current global step (used for scheduling and filenames)
    :param device: device to run the pipeline on
    :param vae: VAE module (moved to ``device`` and back around sampling)
    :param tokenizer: tokenizer for the pipeline
    :param text_encoder: text encoder for the pipeline
    :param unet: unet for the pipeline
    :param prompt_replacement: optional (old, new) pair applied to each prompt
    :param force_sample: bypass the every-N-steps/epochs scheduling checks
    """
    if not force_sample:
        # sampling disabled entirely when neither cadence is configured
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # epoch cadence takes priority; sample_every_n_steps is ignored
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
                return

    # NOTE(review): read without hasattr(), unlike is_generating_only below —
    # confirm args.sample_only is always present on this namespace
    is_sample_only = args.sample_only
    is_generating_only = hasattr(args, "is_generating_only") and args.is_generating_only

    print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
    if not os.path.isfile(args.sample_prompts):
        print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
        return

    org_vae_device = vae.device  # expected to still be on the CPU here
    vae.to(device)

    # read prompts

    # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
    #     prompts = f.readlines()

    if args.sample_prompts.endswith(".txt"):
        with open(args.sample_prompts, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # drop blank lines and '#' comment lines
        prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
    elif args.sample_prompts.endswith(".json"):
        with open(args.sample_prompts, "r", encoding="utf-8") as f:
            prompts = json.load(f)

    # build the noise scheduler matching the requested sampler name
    sched_init_args = {}
    if args.sample_sampler == "ddim":
        scheduler_cls = DDIMScheduler
    elif args.sample_sampler == "ddpm":  # ddpm produces broken output, so it is excluded from the CLI options
        scheduler_cls = DDPMScheduler
    elif args.sample_sampler == "pndm":
        scheduler_cls = PNDMScheduler
    elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
        scheduler_cls = LMSDiscreteScheduler
    elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
        scheduler_cls = EulerDiscreteScheduler
    elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
        scheduler_cls = EulerAncestralDiscreteScheduler
    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
        # the sampler name doubles as the DPM-Solver algorithm_type
        sched_init_args["algorithm_type"] = args.sample_sampler
    elif args.sample_sampler == "dpmsingle":
        scheduler_cls = DPMSolverSinglestepScheduler
    elif args.sample_sampler == "heun":
        scheduler_cls = HeunDiscreteScheduler
    elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
        scheduler_cls = KDPM2DiscreteScheduler
    elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
        scheduler_cls = KDPM2AncestralDiscreteScheduler
    else:
        # unknown sampler name: silently fall back to DDIM
        scheduler_cls = DDIMScheduler

    if args.v_parameterization:
        # v-prediction models need the matching prediction_type
        sched_init_args["prediction_type"] = "v_prediction"

    scheduler = scheduler_cls(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDLER_SCHEDULE,
        **sched_init_args,
    )

    # force clip_sample=True (some schedulers default it off)
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
        # print("set clip_sample to True")
        scheduler.config.clip_sample = True

    pipeline = StableDiffusionLongPromptWeightingPipeline(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        clip_skip=args.clip_skip,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )
    pipeline.to(device)

    if is_generating_only:
        save_dir = args.output_dir
    else:
        save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # snapshot RNG state so sampling does not perturb training determinism
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    with torch.no_grad():
        with accelerator.autocast():
            for i, prompt in enumerate(prompts):
                # only the main process generates images
                if not accelerator.is_main_process:
                    continue

                if isinstance(prompt, dict):
                    # json form: explicit per-prompt settings with defaults
                    negative_prompt = prompt.get("negative_prompt")
                    sample_steps = prompt.get("sample_steps", 30)
                    width = prompt.get("width", 512)
                    height = prompt.get("height", 512)
                    scale = prompt.get("scale", 7.5)
                    seed = prompt.get("seed")
                    prompt = prompt.get("prompt")

                    prompt = replace_filewords_prompt(prompt, args)
                    negative_prompt = replace_filewords_prompt(negative_prompt, args)
                else:
                    prompt = replace_filewords_prompt(prompt, args)
                    # prompt = prompt.strip()
                    # if len(prompt) == 0 or prompt[0] == "#":
                    #     continue

                    # text form: parse " --x value" flags (subset of gen_img_diffusers)
                    prompt_args = prompt.split(" --")
                    prompt = prompt_args[0]
                    negative_prompt = None
                    sample_steps = 30
                    width = height = 512
                    scale = 7.5
                    seed = None
                    for parg in prompt_args:
                        try:
                            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
                            if m:  # width
                                width = int(m.group(1))
                                continue

                            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
                            if m:  # height
                                height = int(m.group(1))
                                continue

                            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
                            if m:  # seed
                                seed = int(m.group(1))
                                continue

                            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
                            if m:  # steps, clamped to 1..1000
                                sample_steps = max(1, min(1000, int(m.group(1))))
                                continue

                            m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
                            if m:  # scale (guidance)
                                scale = float(m.group(1))
                                continue

                            m = re.match(r"n (.+)", parg, re.IGNORECASE)
                            if m:  # negative prompt
                                negative_prompt = m.group(1)
                                continue

                        except ValueError as ex:
                            # bad numeric value in a flag: report and keep going
                            print(f"Exception in parsing / 解析エラー: {parg}")
                            print(ex)

                if seed is not None:
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed(seed)

                if prompt_replacement is not None:
                    prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                    if negative_prompt is not None:
                        negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

                height = max(64, height - height % 8)  # round to divisible by 8
                width = max(64, width - width % 8)  # round to divisible by 8
                print(f"prompt: {prompt}")
                print(f"negative_prompt: {negative_prompt}")
                print(f"height: {height}")
                print(f"width: {width}")
                print(f"sample_steps: {sample_steps}")
                print(f"scale: {scale}")
                image = pipeline(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=sample_steps,
                    guidance_scale=scale,
                    negative_prompt=negative_prompt,
                ).images[0]

                ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
                num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
                seed_suffix = "" if seed is None else f"_{seed}"

                if is_generating_only:
                    # NOTE(review): this branch includes num_suffix while the
                    # other does not — confirm the asymmetry is intentional
                    img_filename = (
                        f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
                    )
                else:
                    img_filename = (
                        f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{i:04d}{seed_suffix}.png"
                    )
                if is_sample_only:
                    # write the prompt next to the image as a .txt sidecar
                    img_path_no_ext = os.path.join(save_dir, img_filename[:-4])
                    with open(img_path_no_ext + ".txt", "w") as f:
                        # put prompt in txt file
                        f.write(prompt)
                        # close file (redundant inside `with`, kept as-is)
                        f.close()

                image.save(os.path.join(save_dir, img_filename))

                # send logs only when wandb is enabled
                try:
                    wandb_tracker = accelerator.get_tracker("wandb")
                    try:
                        import wandb
                    except ImportError:  # availability is checked once beforehand, so this should not raise here
                        raise ImportError("No wandb / wandb がインストールされていないようです")

                    wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
                except:  # wandb disabled; NOTE(review): bare except also hides real errors — consider narrowing
                    pass

    # clear pipeline and cache to reduce vram usage
    del pipeline
    torch.cuda.empty_cache()

    # restore the RNG state captured before sampling
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)
    vae.to(org_vae_device)
|
||||
Reference in New Issue
Block a user