added target weight to targets

This commit is contained in:
Jaret Burkett
2023-07-23 14:08:37 -06:00
parent 452f2a6da2
commit 9a2819900c
2 changed files with 17 additions and 8 deletions

View File

@@ -116,6 +116,7 @@ class SliderTargetConfig:
self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None)
self.multiplier: float = kwargs.get('multiplier', 1.0)
self.weight: float = kwargs.get('weight', 1.0)
class SliderConfig:
@@ -137,6 +138,7 @@ class EncodedPromptPair:
height=512,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0,
weight=1.0
):
self.target_class = target_class
self.positive = positive
@@ -146,6 +148,7 @@ class EncodedPromptPair:
self.height = height
self.action: int = action
self.multiplier = multiplier
self.weight = weight
class TrainSliderProcess(BaseTrainProcess):
@@ -429,7 +432,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier
multiplier=target.multiplier,
weight=target.weight
),
# erase inverted
EncodedPromptPair(
@@ -440,7 +444,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0
multiplier=target.multiplier * -1.0,
weight=target.weight
),
# enhance standard, swap pos neg
EncodedPromptPair(
@@ -451,7 +456,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier
multiplier=target.multiplier,
weight=target.weight
),
# enhance inverted
EncodedPromptPair(
@@ -462,7 +468,8 @@ class TrainSliderProcess(BaseTrainProcess):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
@@ -494,6 +501,7 @@ class TrainSliderProcess(BaseTrainProcess):
neutral = prompt_pair.neutral
negative = prompt_pair.negative
positive = prompt_pair.positive
weight = prompt_pair.weight
# set network multiplier
self.network.multiplier = prompt_pair.multiplier
@@ -621,7 +629,7 @@ class TrainSliderProcess(BaseTrainProcess):
loss = loss_function(
target_latents,
offset_neutral,
)
) * weight
loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'):