mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-22 23:39:21 +00:00
Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun
This commit is contained in:
@@ -8,11 +8,12 @@ from torch.utils.data import ConcatDataset, DataLoader
|
||||
from toolkit.data_loader import PairedImageDataset
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
import random
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -41,6 +42,7 @@ class DatasetConfig:
|
||||
class ReferenceSliderConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
|
||||
self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
|
||||
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
|
||||
|
||||
|
||||
@@ -98,10 +100,19 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
with torch.no_grad():
|
||||
imgs, prompts, network_weights = batch
|
||||
network_pos_weight, network_neg_weight = network_weights
|
||||
|
||||
if isinstance(network_pos_weight, torch.Tensor):
|
||||
network_pos_weight = network_pos_weight.item()
|
||||
if isinstance(network_neg_weight, torch.Tensor):
|
||||
network_neg_weight = network_neg_weight.item()
|
||||
|
||||
# get an array of random floats between -weight_jitter and weight_jitter
|
||||
weight_jitter = self.slider_config.weight_jitter
|
||||
if weight_jitter > 0.0:
|
||||
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
||||
network_pos_weight += jitter_list
|
||||
network_neg_weight += jitter_list
|
||||
|
||||
# if items in network_weight list are tensors, convert them to floats
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -211,6 +222,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
# todo add snr gamma here
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
loss_slide_float = loss.item()
|
||||
|
||||
Reference in New Issue
Block a user