mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Added my annotator/preprocessor and improved network jitter on reference trainer
This commit is contained in:
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -4,3 +4,6 @@
|
||||
[submodule "repositories/leco"]
|
||||
path = repositories/leco
|
||||
url = https://github.com/p1atdev/LECO
|
||||
[submodule "repositories/batch_annotator"]
|
||||
path = repositories/batch_annotator
|
||||
url = https://github.com/ostris/batch-annotator
|
||||
|
||||
@@ -16,6 +16,7 @@ from toolkit import train_tools
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
import random
|
||||
from toolkit.basic import value_map
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -91,11 +92,18 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
network_neg_weight = network_neg_weight.item()
|
||||
|
||||
# get an array of random floats between -weight_jitter and weight_jitter
|
||||
loss_jitter_multiplier = 1.0
|
||||
weight_jitter = self.slider_config.weight_jitter
|
||||
if weight_jitter > 0.0:
|
||||
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
||||
orig_network_pos_weight = network_pos_weight
|
||||
network_pos_weight += jitter_list
|
||||
network_neg_weight += (jitter_list * -1.0)
|
||||
# penalize the loss for its distance from network_pos_weight
|
||||
# a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter
|
||||
# so the loss_jitter_multiplier needs to be 0.4
|
||||
loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0)
|
||||
|
||||
|
||||
# if items in network_weight list are tensors, convert them to floats
|
||||
|
||||
@@ -210,7 +218,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
loss = loss.mean() * loss_jitter_multiplier
|
||||
loss_slide_float = loss.item()
|
||||
|
||||
loss_float = loss.item()
|
||||
|
||||
1
repositories/batch_annotator
Submodule
1
repositories/batch_annotator
Submodule
Submodule repositories/batch_annotator added at 420e142f6a
Reference in New Issue
Block a user