mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun
This commit is contained in:
@@ -12,6 +12,7 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
|
||||
add_base_model_info_to_meta
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import random
|
||||
|
||||
|
||||
class GenerateConfig:
|
||||
@@ -41,6 +42,10 @@ class GenerateConfig:
|
||||
else:
|
||||
raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
|
||||
|
||||
if kwargs.get('shuffle', False):
|
||||
# shuffle the prompts
|
||||
random.shuffle(self.prompts)
|
||||
|
||||
|
||||
class GenerateProcess(BaseProcess):
|
||||
process_id: int
|
||||
|
||||
@@ -1,21 +1,9 @@
|
||||
# ref:
|
||||
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from safetensors.torch import save_file, load_file
|
||||
import torch.utils.checkpoint as cp
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.config_modules import SliderConfig
|
||||
from toolkit.layers import CheckpointGradients
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
import sys
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
from toolkit.prompt_utils import \
|
||||
@@ -256,9 +244,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
noise_scheduler.set_timesteps(1000)
|
||||
|
||||
current_timestep = noise_scheduler.timesteps[
|
||||
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
]
|
||||
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||
|
||||
# flush() # 4.2GB to 3GB on 512x512
|
||||
|
||||
@@ -401,10 +388,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
offset_neutral += offset
|
||||
|
||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||
loss = loss_function(
|
||||
target_latents,
|
||||
offset_neutral,
|
||||
) * prompt_pair_chunk.weight
|
||||
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# match batch size
|
||||
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() * prompt_pair_chunk.weight
|
||||
|
||||
loss.backward()
|
||||
loss_list.append(loss.item())
|
||||
|
||||
Reference in New Issue
Block a user