Added multiplier jitter, min_snr, ability to choose which SDXL encoders to use, a shuffle option for the prompt generator, and other fun features

This commit is contained in:
Jaret Burkett
2023-08-19 05:54:22 -06:00
parent 80e2f4a2a4
commit 90eedb78bf
7 changed files with 239 additions and 35 deletions

View File

@@ -12,6 +12,7 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
add_base_model_info_to_meta
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
import random
class GenerateConfig:
@@ -41,6 +42,10 @@ class GenerateConfig:
else:
raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
if kwargs.get('shuffle', False):
# shuffle the prompts
random.shuffle(self.prompts)
class GenerateProcess(BaseProcess):
process_id: int

View File

@@ -1,21 +1,9 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import random
from collections import OrderedDict
import os
from typing import Optional, Union
from safetensors.torch import save_file, load_file
import torch.utils.checkpoint as cp
from tqdm import tqdm
from toolkit.config_modules import SliderConfig
from toolkit.layers import CheckpointGradients
from toolkit.paths import REPOS_ROOT
import sys
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
from toolkit.prompt_utils import \
@@ -256,9 +244,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
noise_scheduler.set_timesteps(1000)
current_timestep = noise_scheduler.timesteps[
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
]
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
# flush() # 4.2GB to 3GB on 512x512
@@ -401,10 +388,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
offset_neutral += offset
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
loss = loss_function(
target_latents,
offset_neutral,
) * prompt_pair_chunk.weight
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# match batch size
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean() * prompt_pair_chunk.weight
loss.backward()
loss_list.append(loss.item())