From 7cd6945082aeaf9e9760880279cec6391577755f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 3 Sep 2023 16:43:51 -0600 Subject: [PATCH] Added my annotator/preprocessor and improved network jitter on reference trainer --- .gitmodules | 3 +++ .../ImageReferenceSliderTrainerProcess.py | 10 +++++++++- repositories/batch_annotator | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) create mode 160000 repositories/batch_annotator diff --git a/.gitmodules b/.gitmodules index 18079356..ebf3e531 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "repositories/leco"] path = repositories/leco url = https://github.com/p1atdev/LECO +[submodule "repositories/batch_annotator"] + path = repositories/batch_annotator + url = https://github.com/ostris/batch-annotator diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py index 9e28a7c4..99c41e0d 100644 --- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py +++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py @@ -16,6 +16,7 @@ from toolkit import train_tools import torch from jobs.process import BaseSDTrainProcess import random +from toolkit.basic import value_map def flush(): @@ -91,11 +92,18 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): network_neg_weight = network_neg_weight.item() # get an array of random floats between -weight_jitter and weight_jitter + loss_jitter_multiplier = 1.0 weight_jitter = self.slider_config.weight_jitter if weight_jitter > 0.0: jitter_list = random.uniform(-weight_jitter, weight_jitter) + orig_network_pos_weight = network_pos_weight network_pos_weight += jitter_list network_neg_weight += (jitter_list * -1.0) + # penalize the loss for its distance from network_pos_weight + # a jitter_list of abs(3.0) on a weight of 5.0 is a 60% jitter + # so the loss_jitter_multiplier needs to be 0.4 + loss_jitter_multiplier = value_map(abs(jitter_list), 0.0, weight_jitter, 1.0, 0.0) + # if items in network_weight list are tensors, convert them to floats @@ -210,7 +218,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess): # add min_snr_gamma loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) - loss = loss.mean() + loss = loss.mean() * loss_jitter_multiplier loss_slide_float = loss.item() loss_float = loss.item() diff --git a/repositories/batch_annotator b/repositories/batch_annotator new file mode 160000 index 00000000..420e142f --- /dev/null +++ b/repositories/batch_annotator @@ -0,0 +1 @@ +Subproject commit 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf