mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
added target weight to targets
This commit is contained in:
@@ -158,9 +158,10 @@ config:
|
|||||||
# negative is the prompt for the negative side of the slider and works the same as positive
|
# negative is the prompt for the negative side of the slider and works the same as positive
|
||||||
# it does not necessarily work the same as a negative prompt when generating images
|
# it does not necessarily work the same as a negative prompt when generating images
|
||||||
negative: "dog"
|
negative: "dog"
|
||||||
# LoRA weight to train this target. I recommend 1.0. Just leave it, it won't work
|
# the loss for this target is multiplied by this number.
|
||||||
# how you expect if you change it
|
# if you are doing more than one target it may be good to set less important ones
|
||||||
multiplier: 1.0
|
# to a lower number like 0.1 so they dont outweigh the primary target
|
||||||
|
weight: 1.0
|
||||||
|
|
||||||
# You can put any information you want here, and it will be saved in the model.
|
# You can put any information you want here, and it will be saved in the model.
|
||||||
# The below is an example, but you can put your grocery list in it if you want.
|
# The below is an example, but you can put your grocery list in it if you want.
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class SliderTargetConfig:
|
|||||||
self.positive: str = kwargs.get('positive', None)
|
self.positive: str = kwargs.get('positive', None)
|
||||||
self.negative: str = kwargs.get('negative', None)
|
self.negative: str = kwargs.get('negative', None)
|
||||||
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||||
|
self.weight: float = kwargs.get('weight', 1.0)
|
||||||
|
|
||||||
|
|
||||||
class SliderConfig:
|
class SliderConfig:
|
||||||
@@ -137,6 +138,7 @@ class EncodedPromptPair:
|
|||||||
height=512,
|
height=512,
|
||||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||||
multiplier=1.0,
|
multiplier=1.0,
|
||||||
|
weight=1.0
|
||||||
):
|
):
|
||||||
self.target_class = target_class
|
self.target_class = target_class
|
||||||
self.positive = positive
|
self.positive = positive
|
||||||
@@ -146,6 +148,7 @@ class EncodedPromptPair:
|
|||||||
self.height = height
|
self.height = height
|
||||||
self.action: int = action
|
self.action: int = action
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
|
||||||
class TrainSliderProcess(BaseTrainProcess):
|
class TrainSliderProcess(BaseTrainProcess):
|
||||||
@@ -429,7 +432,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||||
multiplier=target.multiplier
|
multiplier=target.multiplier,
|
||||||
|
weight=target.weight
|
||||||
),
|
),
|
||||||
# erase inverted
|
# erase inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -440,7 +444,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||||
multiplier=target.multiplier * -1.0
|
multiplier=target.multiplier * -1.0,
|
||||||
|
weight=target.weight
|
||||||
),
|
),
|
||||||
# enhance standard, swap pos neg
|
# enhance standard, swap pos neg
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -451,7 +456,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||||
multiplier=target.multiplier
|
multiplier=target.multiplier,
|
||||||
|
weight=target.weight
|
||||||
),
|
),
|
||||||
# enhance inverted
|
# enhance inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -462,7 +468,8 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||||
multiplier=target.multiplier * -1.0
|
multiplier=target.multiplier * -1.0,
|
||||||
|
weight=target.weight
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -494,6 +501,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
neutral = prompt_pair.neutral
|
neutral = prompt_pair.neutral
|
||||||
negative = prompt_pair.negative
|
negative = prompt_pair.negative
|
||||||
positive = prompt_pair.positive
|
positive = prompt_pair.positive
|
||||||
|
weight = prompt_pair.weight
|
||||||
|
|
||||||
# set network multiplier
|
# set network multiplier
|
||||||
self.network.multiplier = prompt_pair.multiplier
|
self.network.multiplier = prompt_pair.multiplier
|
||||||
@@ -621,7 +629,7 @@ class TrainSliderProcess(BaseTrainProcess):
|
|||||||
loss = loss_function(
|
loss = loss_function(
|
||||||
target_latents,
|
target_latents,
|
||||||
offset_neutral,
|
offset_neutral,
|
||||||
)
|
) * weight
|
||||||
|
|
||||||
loss_float = loss.item()
|
loss_float = loss.item()
|
||||||
if self.train_config.optimizer.startswith('dadaptation'):
|
if self.train_config.optimizer.startswith('dadaptation'):
|
||||||
|
|||||||
Reference in New Issue
Block a user