added target weight to targets

This commit is contained in:
Jaret Burkett
2023-07-23 14:08:37 -06:00
parent 452f2a6da2
commit 9a2819900c
2 changed files with 17 additions and 8 deletions

View File

@@ -158,9 +158,10 @@ config:
# negative is the prompt for the negative side of the slider and works the same as positive # negative is the prompt for the negative side of the slider and works the same as positive
# it does not necessarily work the same as a negative prompt when generating images # it does not necessarily work the same as a negative prompt when generating images
negative: "dog" negative: "dog"
# LoRA weight to train this target. I recommend 1.0. Just leave it, it won't work # the loss for this target is multiplied by this number.
# how you expect if you change it # if you are doing more than one target it may be good to set less important ones
multiplier: 1.0 # to a lower number like 0.1 so they dont outweigh the primary target
weight: 1.0
# You can put any information you want here, and it will be saved in the model. # You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want. # The below is an example, but you can put your grocery list in it if you want.

View File

@@ -116,6 +116,7 @@ class SliderTargetConfig:
self.positive: str = kwargs.get('positive', None) self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None) self.negative: str = kwargs.get('negative', None)
self.multiplier: float = kwargs.get('multiplier', 1.0) self.multiplier: float = kwargs.get('multiplier', 1.0)
self.weight: float = kwargs.get('weight', 1.0)
class SliderConfig: class SliderConfig:
@@ -137,6 +138,7 @@ class EncodedPromptPair:
height=512, height=512,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0, multiplier=1.0,
weight=1.0
): ):
self.target_class = target_class self.target_class = target_class
self.positive = positive self.positive = positive
@@ -146,6 +148,7 @@ class EncodedPromptPair:
self.height = height self.height = height
self.action: int = action self.action: int = action
self.multiplier = multiplier self.multiplier = multiplier
self.weight = weight
class TrainSliderProcess(BaseTrainProcess): class TrainSliderProcess(BaseTrainProcess):
@@ -429,7 +432,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width, width=width,
height=height, height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier multiplier=target.multiplier,
weight=target.weight
), ),
# erase inverted # erase inverted
EncodedPromptPair( EncodedPromptPair(
@@ -440,7 +444,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width, width=width,
height=height, height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0 multiplier=target.multiplier * -1.0,
weight=target.weight
), ),
# enhance standard, swap pos neg # enhance standard, swap pos neg
EncodedPromptPair( EncodedPromptPair(
@@ -451,7 +456,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width, width=width,
height=height, height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier multiplier=target.multiplier,
weight=target.weight
), ),
# enhance inverted # enhance inverted
EncodedPromptPair( EncodedPromptPair(
@@ -462,7 +468,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width, width=width,
height=height, height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0 multiplier=target.multiplier * -1.0,
weight=target.weight
), ),
] ]
@@ -494,6 +501,7 @@ class TrainSliderProcess(BaseTrainProcess):
neutral = prompt_pair.neutral neutral = prompt_pair.neutral
negative = prompt_pair.negative negative = prompt_pair.negative
positive = prompt_pair.positive positive = prompt_pair.positive
weight = prompt_pair.weight
# set network multiplier # set network multiplier
self.network.multiplier = prompt_pair.multiplier self.network.multiplier = prompt_pair.multiplier
@@ -621,7 +629,7 @@ class TrainSliderProcess(BaseTrainProcess):
loss = loss_function( loss = loss_function(
target_latents, target_latents,
offset_neutral, offset_neutral,
) ) * weight
loss_float = loss.item() loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'): if self.train_config.optimizer.startswith('dadaptation'):