mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
added target weight to targets
This commit is contained in:
@@ -116,6 +116,7 @@ class SliderTargetConfig:
|
||||
self.positive: str = kwargs.get('positive', None)
|
||||
self.negative: str = kwargs.get('negative', None)
|
||||
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||
self.weight: float = kwargs.get('weight', 1.0)
|
||||
|
||||
|
||||
class SliderConfig:
|
||||
@@ -137,6 +138,7 @@ class EncodedPromptPair:
|
||||
height=512,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=1.0,
|
||||
weight=1.0
|
||||
):
|
||||
self.target_class = target_class
|
||||
self.positive = positive
|
||||
@@ -146,6 +148,7 @@ class EncodedPromptPair:
|
||||
self.height = height
|
||||
self.action: int = action
|
||||
self.multiplier = multiplier
|
||||
self.weight = weight
|
||||
|
||||
|
||||
class TrainSliderProcess(BaseTrainProcess):
|
||||
@@ -429,7 +432,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
),
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
@@ -440,7 +444,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
),
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
@@ -451,7 +456,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
),
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
@@ -462,7 +468,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
|
||||
@@ -494,6 +501,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
neutral = prompt_pair.neutral
|
||||
negative = prompt_pair.negative
|
||||
positive = prompt_pair.positive
|
||||
weight = prompt_pair.weight
|
||||
|
||||
# set network multiplier
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
@@ -621,7 +629,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
||||
loss = loss_function(
|
||||
target_latents,
|
||||
offset_neutral,
|
||||
)
|
||||
) * weight
|
||||
|
||||
loss_float = loss.item()
|
||||
if self.train_config.optimizer.startswith('dadaptation'):
|
||||
|
||||
Reference in New Issue
Block a user