Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun

This commit is contained in:
Jaret Burkett
2023-08-19 05:54:22 -06:00
parent 80e2f4a2a4
commit 90eedb78bf
7 changed files with 239 additions and 35 deletions

View File

@@ -8,11 +8,12 @@ from torch.utils.data import ConcatDataset, DataLoader
from toolkit.data_loader import PairedImageDataset
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
from toolkit.train_tools import get_torch_dtype
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random
def flush():
@@ -41,6 +42,7 @@ class DatasetConfig:
class ReferenceSliderConfig:
def __init__(self, **kwargs):
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
@@ -98,10 +100,19 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
with torch.no_grad():
imgs, prompts, network_weights = batch
network_pos_weight, network_neg_weight = network_weights
if isinstance(network_pos_weight, torch.Tensor):
network_pos_weight = network_pos_weight.item()
if isinstance(network_neg_weight, torch.Tensor):
network_neg_weight = network_neg_weight.item()
# draw one random float in [-weight_jitter, weight_jitter]; the same offset is added to both weights
weight_jitter = self.slider_config.weight_jitter
if weight_jitter > 0.0:
jitter_list = random.uniform(-weight_jitter, weight_jitter)
network_pos_weight += jitter_list
network_neg_weight += jitter_list
# if items in network_weight list are tensors, convert them to floats
dtype = get_torch_dtype(self.train_config.dtype)
@@ -211,6 +222,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss = loss.mean([1, 2, 3])
# apply min-SNR gamma weighting to the per-sample loss when min_snr_gamma is configured
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss_slide_float = loss.item()