diff --git a/extensions_built_in/concept_replacer/ConceptReplacer.py b/extensions_built_in/concept_replacer/ConceptReplacer.py new file mode 100644 index 00000000..b451210d --- /dev/null +++ b/extensions_built_in/concept_replacer/ConceptReplacer.py @@ -0,0 +1,159 @@ +import random +from collections import OrderedDict +from torch.utils.data import DataLoader +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +import torch +from jobs.process import BaseSDTrainProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class ConceptReplacementConfig: + def __init__(self, **kwargs): + self.concept: str = kwargs.get('concept', '') + self.replacement: str = kwargs.get('replacement', '') + + +class ConceptReplacer(BaseSDTrainProcess): + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + replacement_list = self.config.get('replacements', []) + self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list] + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + + # textual inversion + if self.embedding is not None: + # keep original embeddings as reference + self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone() + # set text encoder to train. Not sure if this is necessary but diffusers example did it + self.sd.text_encoder.train() + + def hook_train_loop(self, batch): + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + network_weight_list = batch.get_network_weight_list() + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + batch_replacement_list = [] + # get a random replacement for each prompt + for prompt in conditioned_prompts: + replacement = random.choice(self.replacement_list) + batch_replacement_list.append(replacement) + + # build out prompts + concept_prompts = [] + replacement_prompts = [] + for idx, replacement in enumerate(batch_replacement_list): + prompt = conditioned_prompts[idx] + + # insert shuffled concept at beginning and end of prompt + shuffled_concept = [x.strip() for x in replacement.concept.split(',')] + random.shuffle(shuffled_concept) + shuffled_concept = ', '.join(shuffled_concept) + concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}") + + # insert replacement at beginning and end of prompt + shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')] + random.shuffle(shuffled_replacement) + shuffled_replacement = ', '.join(shuffled_replacement) + replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}") + + # predict the replacement without network + conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype) + + replacement_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + del conditional_embeds + replacement_pred = 
replacement_pred.detach() + + self.optimizer.zero_grad() + flush() + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding: + grad_on_text_encoder = True + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exists + with network: + with torch.set_grad_enabled(grad_on_text_encoder): + # embed the prompts + conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype) + if not grad_on_text_encoder: + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + self.optimizer.zero_grad() + flush() + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # back propagate loss, then flush to free memory + loss.backward() + flush() + + # apply gradients + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + if self.embedding is not None: + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool) + index_no_updates[ + min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False + with torch.no_grad(): + self.sd.text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = self.orig_embeds_params[index_no_updates] + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + # reset network multiplier + network.multiplier = 1.0 + + return loss_dict diff --git a/extensions_built_in/concept_replacer/__init__.py b/extensions_built_in/concept_replacer/__init__.py new file mode 100644 index 00000000..69dc7311 --- /dev/null +++ b/extensions_built_in/concept_replacer/__init__.py @@ -0,0 +1,26 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas.
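+# Roughly how a registered extension is expected to be resolved (an illustrative
+# sketch only -- the actual lookup lives in the toolkit's job loader, and
+# `process_type` / `process_config` below are placeholder names):
+#
+#     for ext in AI_TOOLKIT_EXTENSIONS:
+#         if ext.uid == process_type:           # the `type:` value from the YAML config
+#             ProcessClass = ext.get_process()  # lazy import of the trainer class
+#             process = ProcessClass(process_id, job, process_config)
+#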
+from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class ConceptReplacerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "concept_replacer" + + # name is the name of the extension for printing + name = "Concept Replacer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ConceptReplacer import ConceptReplacer + return ConceptReplacer + + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ConceptReplacerExtension, +] diff --git a/extensions_built_in/concept_replacer/config/train.example.yaml b/extensions_built_in/concept_replacer/config/train.example.yaml new file mode 100644 index 00000000..793d5d55 --- /dev/null +++ b/extensions_built_in/concept_replacer/config/train.example.yaml @@ -0,0 +1,91 @@ +--- +job: extension +config: + name: test_v1 + process: + - type: 'textual_inversion_trainer' + training_folder: "out/TI" + device: cuda:0 + # for tensorboard logging + log_dir: "out/.tensorboard" + embedding: + trigger: "your_trigger_here" + tokens: 12 + init_words: "man with short brown hair" + save_format: "safetensors" # 'safetensors' or 'pt' + save: + dtype: float16 # precision to save + save_every: 100 # save every this many steps + max_step_saves_to_keep: 5 # only affects step counts + datasets: + - folder_path: "/path/to/dataset" + caption_ext: "txt" + default_caption: "[trigger]" + buckets: true + resolution: 512 + train: + noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a" + steps: 3000 + weight_jitter: 0.0 + lr: 5e-5 + train_unet: false + gradient_checkpointing: true + train_text_encoder: false + optimizer: "adamw" +# optimizer: "prodigy" + optimizer_params: + weight_decay: 1e-2 + lr_scheduler: "constant" + max_denoising_steps: 1000 + batch_size: 4 + dtype: bf16 + xformers: true + min_snr_gamma: 5.0 +# skip_first_sample: true + noise_offset: 0.0 # not needed for this + model: + # objective reality v2 + name_or_path: "https://civitai.com/models/128453?modelVersionId=142465" + is_v2: false # for v2 models + is_xl: false # for SDXL models + is_v_pred: false # for v-prediction models (most v2 models) + sample: + sampler: "ddpm" # must match train.noise_scheduler + sample_every: 100 # sample every this many steps + width: 512 + height: 512 + prompts: + - "photo of [trigger] laughing" + - "photo of [trigger] smiling" + - "[trigger] close up" + - "dark scene [trigger] frozen" + - "[trigger] nighttime" + - "a painting of [trigger]" + - "a drawing of [trigger]" + - "a cartoon of [trigger]" + - "[trigger] pixar style" + - "[trigger] costume" + neg: "" + seed: 42 + walk_seed: false + guidance_scale: 7 + sample_steps: 20 + network_multiplier: 1.0 + + logging: + log_every: 10 # log every this many steps + use_wandb: false # not supported yet + verbose: false + +# You can put any information you want here, and it will be saved in the model. +# The below is an example, but you can put your grocery list in it if you want. +# It is saved in the model so be aware of that. 
The software will include this +# plus some other information for you automatically +meta: + # [name] gets replaced with the name above + name: "[name]" +# version: '1.0' +# creator: +# name: Your Name +# email: your@gmail.com +# website: https://your.website diff --git a/scripts/convert_cog.py b/scripts/convert_cog.py new file mode 100644 index 00000000..ba4f6e73 --- /dev/null +++ b/scripts/convert_cog.py @@ -0,0 +1,128 @@ +import json +from collections import OrderedDict +import os +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +device = torch.device('cpu') + +# [diffusers] -> kohya +embedding_mapping = { + 'text_encoders_0': 'clip_l', + 'text_encoders_1': 'clip_g' +} + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps') +sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json') + +# load keymap +with open(sdxl_keymap_path, 'r') as f: + ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap'] + +# invert the item / key pairs +diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()} + + +def get_ldm_key(diffuser_key): + diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}" + diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight') + diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight') + diffuser_key = diffuser_key.replace('_alpha', '.alpha') + diffuser_key = diffuser_key.replace('_processor_to_', '_to_') + diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.') + if diffuser_key in diffusers_ldm_keymap: + return diffusers_ldm_keymap[diffuser_key] + else: + raise KeyError(f"Key {diffuser_key} not found in keymap") + + +def convert_cog(lora_path, embedding_path): + embedding_state_dict = OrderedDict() + lora_state_dict = OrderedDict() + + # # normal dict + # normal_dict = OrderedDict() + # example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors" + # with safe_open(example_path, framework="pt", device='cpu') as f: + # keys = list(f.keys()) + # for key in keys: + # normal_dict[key] = f.get_tensor(key) + + with safe_open(embedding_path, framework="pt", device='cpu') as f: + keys = list(f.keys()) + for key in keys: + new_key = embedding_mapping[key] + embedding_state_dict[new_key] = f.get_tensor(key) + + with safe_open(lora_path, framework="pt", device='cpu') as f: + keys = list(f.keys()) + lora_rank = None + + # get the lora dim first. 
Check first 3 linear layers just to be safe + for key in keys: + new_key = get_ldm_key(key) + tensor = f.get_tensor(key) + num_checked = 0 + if len(tensor.shape) == 2: + this_dim = min(tensor.shape) + if lora_rank is None: + lora_rank = this_dim + elif lora_rank != this_dim: + raise ValueError(f"lora rank is not consistent, got {tensor.shape}") + else: + num_checked += 1 + if num_checked >= 3: + break + + for key in keys: + new_key = get_ldm_key(key) + tensor = f.get_tensor(key) + if new_key.endswith('.lora_down.weight'): + alpha_key = new_key.replace('.lora_down.weight', '.alpha') + # diffusers does not have alpha, they usa an alpha multiplier of 1 which is a tensor weight of the dims + # assume first smallest dim is the lora rank if shape is 2 + lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank + + lora_state_dict[new_key] = tensor + + return lora_state_dict, embedding_state_dict + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + 'lora_path', + type=str, + help='Path to lora file' + ) + parser.add_argument( + 'embedding_path', + type=str, + help='Path to embedding file' + ) + + parser.add_argument( + '--lora_output', + type=str, + default="lora_output", + ) + + parser.add_argument( + '--embedding_output', + type=str, + default="embedding_output", + ) + + args = parser.parse_args() + + lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path) + + # save them + save_file(lora_state_dict, args.lora_output) + save_file(embedding_state_dict, args.embedding_output) + print(f"Saved lora to {args.lora_output}") + print(f"Saved embedding to {args.embedding_output}") diff --git a/scripts/train_dreambooth.py b/scripts/train_dreambooth.py deleted file mode 100644 index 8442ddeb..00000000 --- a/scripts/train_dreambooth.py +++ /dev/null @@ -1,547 +0,0 @@ -import gc -import time -import argparse -import itertools -import math -import os -from multiprocessing import Value - -from tqdm import tqdm -import torch -from accelerate.utils import set_seed -import diffusers -from diffusers import DDPMScheduler - -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import custom_tools.train_tools as train_tools -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - get_weighted_text_embeddings, - prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, - scale_v_prediction_loss_like_noise_prediction, -) - -# perlin_noise, - -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -SD_SCRIPTS_ROOT = os.path.join(PROJECT_ROOT, "repositories", "sd-scripts") - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, False) - - cache_latents = args.cache_latents - - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because 
config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - - if args.no_token_padding: - train_dataset_group.disable_token_padding() - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # replace captions with names - if args.name_replace is not None: - print(f"Replacing captions [name] with '{args.name_replace}'") - - train_dataset_group = train_tools.replace_filewords_in_dataset_group( - train_dataset_group, args - ) - - # acceleratorを準備する - print("prepare accelerator") - - if args.gradient_accumulation_steps > 1: - print( - f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" - ) - print( - f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" - ) - - accelerator, unwrap_model = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) - - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - accelerator.wait_for_everyone() - - # 学習を準備する:モデルを適切な状態にする - train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 - unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(train_text_encoder) - 
if not train_text_encoder: - print("Text Encoder is not trained.") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - if train_text_encoder: - trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) - else: - trainable_params = unet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collater, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - if args.stop_text_encoder_training is None: - args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - - # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - # transform DDP after prepare - text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) - - if not train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: 
{len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - - if accelerator.is_main_process: - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name) - - if args.sample_first or args.sample_only: - # Do initial sample before starting training - train_tools.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, - text_encoder, unet, force_sample=True) - - if args.sample_only: - return - loss_list = [] - loss_total = 0.0 - for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - unet.train() - # train==True is required to enable gradient_checkpointing - if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: - text_encoder.train() - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - # 指定したステップ数でText Encoderの学習を止める - if global_step == args.stop_text_encoder_training: - print(f"stop text encoder training at step {global_step}") - if not args.gradient_checkpointing: - text_encoder.train(False) - text_encoder.requires_grad_(False) - - with accelerator.accumulate(unet): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Sample noise that we'll add to the latents - if args.train_noise_seed is not None: - torch.manual_seed(args.train_noise_seed) - torch.cuda.manual_seed(args.train_noise_seed) - # make same seed for each item in the batch by stacking them - single_noise = torch.randn_like(latents[0]) - noise = torch.stack([single_noise for _ in range(b_size)]) - noise = noise.to(latents.device) - elif args.seed_lock: - noise = train_tools.get_noise_from_latents(latents) - else: - noise = torch.randn_like(latents, device=latents.device) - - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) - # elif args.perlin_noise: - # noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently - - # Get the text embedding for conditioning - with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - 
else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - if train_text_encoder: - params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) - else: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - False, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - unwrap_model(text_encoder), - unwrap_model(unet), - vae, - ) - - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] - ) - accelerator.log(logs, step=global_step) - - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, 
step=epoch + 1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - # checking for saving is in util - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - unwrap_model(text_encoder), - unwrap_model(unet), - vae, - ) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - is_main_process = accelerator.is_main_process - if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end( - args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae - ) - print("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False, True) - train_util.add_training_arguments(parser, True) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--no_token_padding", - action="store_true", - help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", - ) - parser.add_argument( - "--stop_text_encoder_training", - type=int, - default=None, - help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", - ) - - parser.add_argument( - "--sample_first", - action="store_true", - help="Sample first interval before training", - default=False - ) - - parser.add_argument( - "--name_replace", - type=str, - help="Replaces [name] in prompts. Used is sampling, training, and regs", - default=None - ) - - parser.add_argument( - "--train_noise_seed", - type=int, - help="Use custom seed for training noise", - default=None - ) - - parser.add_argument( - "--sample_only", - action="store_true", - help="Only generate samples. 
Used for generating training data with specific seeds to alter during training", - default=False - ) - - parser.add_argument( - "--seed_lock", - action="store_true", - help="Locks the seed to the latent images so the same latent will always have the same noise", - default=False - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/toolkit/buckets.py b/toolkit/buckets.py new file mode 100644 index 00000000..e7b6b1af --- /dev/null +++ b/toolkit/buckets.py @@ -0,0 +1,110 @@ +from typing import Type, List, Union + +BucketResolution = Type[{"width": int, "height": int}] + +# resolutions SDXL was trained on with a 1024x1024 base resolution +resolutions_1024: List[BucketResolution] = [ + # SDXL Base resolution + {"width": 1024, "height": 1024}, + # SDXL Resolutions, widescreen + {"width": 2048, "height": 512}, + {"width": 1984, "height": 512}, + {"width": 1920, "height": 512}, + {"width": 1856, "height": 512}, + {"width": 1792, "height": 576}, + {"width": 1728, "height": 576}, + {"width": 1664, "height": 576}, + {"width": 1600, "height": 640}, + {"width": 1536, "height": 640}, + {"width": 1472, "height": 704}, + {"width": 1408, "height": 704}, + {"width": 1344, "height": 704}, + {"width": 1344, "height": 768}, + {"width": 1280, "height": 768}, + {"width": 1216, "height": 832}, + {"width": 1152, "height": 832}, + {"width": 1152, "height": 896}, + {"width": 1088, "height": 896}, + {"width": 1088, "height": 960}, + {"width": 1024, "height": 960}, + # SDXL Resolutions, portrait + {"width": 960, "height": 1024}, + {"width": 960, "height": 1088}, + {"width": 896, "height": 1088}, + {"width": 896, "height": 1152}, + {"width": 832, "height": 1152}, + {"width": 832, "height": 1216}, + {"width": 768, "height": 1280}, + {"width": 768, "height": 1344}, + {"width": 704, "height": 1408}, + {"width": 704, "height": 1472}, + {"width": 640, "height": 1536}, + {"width": 640, "height": 1600}, + {"width": 576, "height": 1664}, + {"width": 576, "height": 1728}, + {"width": 576, "height": 1792}, + {"width": 512, "height": 1856}, + {"width": 512, "height": 1920}, + {"width": 512, "height": 1984}, + {"width": 512, "height": 2048}, +] + + +def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]: + # determine scaler form 1024 to resolution + scaler = resolution / 1024 + + bucket_size_list = [] + for bucket in resolutions_1024: + # must be divisible by 8 + width = int(bucket["width"] * scaler) + height = int(bucket["height"] * scaler) + if width % divisibility != 0: + width = width - (width % divisibility) + if height % divisibility != 0: + height = height - (height % divisibility) + bucket_size_list.append({"width": width, "height": height}) + + return bucket_size_list + + +def get_bucket_for_image_size( + width: int, + height: int, + bucket_size_list: List[BucketResolution] = None, + resolution: Union[int, None] = None +) -> BucketResolution: + if bucket_size_list is None and resolution is None: + raise ValueError("Must provide either bucket_size_list or resolution") + if bucket_size_list is None: + bucket_size_list = get_bucket_sizes(resolution=resolution) + + # Check for exact match first + for bucket in bucket_size_list: + if bucket["width"] == width and bucket["height"] == height: + return bucket + + # If exact match not found, find the closest bucket + closest_bucket = None + min_removed_pixels = float("inf") + + for bucket in 
bucket_size_list: + scale_w = bucket["width"] / width + scale_h = bucket["height"] / height + + # To minimize pixels, we use the larger scale factor to minimize the amount that has to be cropped. + scale = max(scale_w, scale_h) + + new_width = int(width * scale) + new_height = int(height * scale) + + removed_pixels = (new_width - bucket["width"]) * new_height + (new_height - bucket["height"]) * new_width + + if removed_pixels < min_removed_pixels: + min_removed_pixels = removed_pixels + closest_bucket = bucket + + if closest_bucket is None: + raise ValueError("No suitable bucket found") + + return closest_bucket diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index a6390368..e83c468b 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -52,7 +52,7 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): self.lora_name = lora_name self.scalar = torch.tensor(1.0) - if org_module.__class__.__name__ == "Conv2d": + if org_module.__class__.__name__ in CONV_MODULES: in_dim = org_module.in_channels out_dim = org_module.out_channels else: @@ -66,7 +66,7 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): # else: self.lora_dim = lora_dim - if org_module.__class__.__name__ == "Conv2d": + if org_module.__class__.__name__ in CONV_MODULES: kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index de2259bd..c591c5de 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -134,18 +134,7 @@ class StableDiffusion: # TODO handle other schedulers # sch = KDPM2DiscreteScheduler if self.noise_scheduler is None: - sch = DDPMScheduler - # do our own scheduler - prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" - scheduler = sch( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.0120, - beta_schedule="scaled_linear", - clip_sample=False, - prediction_type=prediction_type, - steps_offset=0 - ) + scheduler = get_sampler('ddpm') self.noise_scheduler = scheduler # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
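Usage note for the new toolkit/buckets.py helper: below is a minimal sketch of how get_bucket_for_image_size might be called when preparing a dataset image. The image size, the resolution value, and the resize/center-crop step afterwards are illustrative assumptions; only the function itself comes from this diff.

    from toolkit.buckets import get_bucket_for_image_size

    # hypothetical portrait photo
    img_width, img_height = 1365, 2048

    # picks the bucket (scaled from the SDXL 1024 list down to `resolution`)
    # that requires cropping away the fewest pixels for this aspect ratio
    bucket = get_bucket_for_image_size(img_width, img_height, resolution=512)

    # the caller is then assumed to scale the image so it covers the bucket
    # and center-crop the overflow down to bucket["width"] x bucket["height"]
    scale = max(bucket["width"] / img_width, bucket["height"] / img_height)
    crop_width, crop_height = bucket["width"], bucket["height"]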