mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added ability to just erase or enhance concepts from a model
This commit is contained in:
@@ -102,62 +102,78 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
cache[prompt] = train_util.encode_prompts(
|
||||
self.sd.tokenizer, self.sd.text_encoder, [prompt]
|
||||
)
|
||||
only_erase = len(target.positive.strip()) == 0
|
||||
only_enhance = len(target.negative.strip()) == 0
|
||||
|
||||
both = not only_erase and not only_enhance
|
||||
|
||||
if only_erase and only_enhance:
|
||||
raise ValueError("target must have at least one of positive or negative or both")
|
||||
# for slider we need to have an enhancer, an eraser, and then
|
||||
# an inverse with negative weights to balance the network
|
||||
# if we don't do this, we will get different contrast and focus.
|
||||
# we only perform actions of enhancing and erasing on the negative
|
||||
# todo work on way to do all of this in one shot
|
||||
prompt_pairs += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.positive],
|
||||
negative=cache[target.negative],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
),
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.negative],
|
||||
negative=cache[target.positive],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
),
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.negative],
|
||||
negative=cache[target.positive],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
),
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.positive],
|
||||
negative=cache[target.negative],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
|
||||
if both or only_erase:
|
||||
prompt_pairs += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.positive],
|
||||
negative=cache[target.negative],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
if both or only_enhance:
|
||||
prompt_pairs += [
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.negative],
|
||||
negative=cache[target.positive],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
if both:
|
||||
prompt_pairs += [
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.negative],
|
||||
negative=cache[target.positive],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
prompt_pairs += [
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
target_class=cache[target.target_class],
|
||||
positive=cache[target.positive],
|
||||
negative=cache[target.negative],
|
||||
neutral=cache[neutral],
|
||||
width=width,
|
||||
height=height,
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
|
||||
# setup anchors
|
||||
anchor_pairs = []
|
||||
|
||||
@@ -65,8 +65,8 @@ class ModelConfig:
|
||||
class SliderTargetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.target_class: str = kwargs.get('target_class', '')
|
||||
self.positive: str = kwargs.get('positive', None)
|
||||
self.negative: str = kwargs.get('negative', None)
|
||||
self.positive: str = kwargs.get('positive', '')
|
||||
self.negative: str = kwargs.get('negative', '')
|
||||
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||
self.weight: float = kwargs.get('weight', 1.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user