Added my annotator/preprocessor and improved network jitter on reference trainer

This commit is contained in:
Jaret Burkett
2023-09-03 16:43:51 -06:00
parent 2a40937b4f
commit 7cd6945082
3 changed files with 13 additions and 1 deletions

3
.gitmodules vendored
View File

@@ -4,3 +4,6 @@
[submodule "repositories/leco"]
path = repositories/leco
url = https://github.com/p1atdev/LECO
[submodule "repositories/batch_annotator"]
path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator

View File

@@ -16,6 +16,7 @@ from toolkit import train_tools
import torch
from jobs.process import BaseSDTrainProcess
import random
from toolkit.basic import value_map
def flush():
@@ -91,11 +92,18 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
network_neg_weight = network_neg_weight.item()
# draw a single random float between -weight_jitter and weight_jitter
loss_jitter_multiplier = 1.0
weight_jitter = self.slider_config.weight_jitter
if weight_jitter > 0.0:
jitter_list = random.uniform(-weight_jitter, weight_jitter)
orig_network_pos_weight = network_pos_weight
network_pos_weight += jitter_list
network_neg_weight += (jitter_list * -1.0)
# penalize the loss in proportion to how far the jitter moved the weights from their original values
# e.g. a jitter of abs(3.0) with a weight_jitter of 5.0 is 60% of the maximum jitter,
# so the loss_jitter_multiplier needs to be 0.4
loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0)
# if items in network_weight list are tensors, convert them to floats
@@ -210,7 +218,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss = loss.mean() * loss_jitter_multiplier
loss_slide_float = loss.item()
loss_float = loss.item()