Added ability to just erase or enhance concepts from a model

This commit is contained in:
Jaret Burkett
2023-07-24 17:33:45 -06:00
parent 61dd818608
commit 7032717294
2 changed files with 68 additions and 52 deletions

View File

@@ -102,62 +102,78 @@ class TrainSliderProcess(BaseSDTrainProcess):
cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt]
)
only_erase = len(target.positive.strip()) == 0
only_enhance = len(target.negative.strip()) == 0
both = not only_erase and not only_enhance
if only_erase and only_enhance:
raise ValueError("target must have at least one of positive or negative or both")
# for slider we need to have an enhancer, an eraser, and then
# an inverse with negative weights to balance the network
# if we don't do this, we will get different contrast and focus.
# we only perform actions of enhancing and erasing on the negative
# todo work on way to do all of this in one shot
prompt_pairs += [
# erase standard
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier,
weight=target.weight
),
# erase inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0,
weight=target.weight
),
# enhance standard, swap pos neg
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier,
weight=target.weight
),
# enhance inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
if both or only_erase:
prompt_pairs += [
# erase standard
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier,
weight=target.weight
),
]
if both or only_enhance:
prompt_pairs += [
# enhance standard, swap pos neg
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier,
weight=target.weight
),
]
if both:
prompt_pairs += [
# erase inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.negative],
negative=cache[target.positive],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
prompt_pairs += [
# enhance inverted
EncodedPromptPair(
target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
# setup anchors
anchor_pairs = []

View File

@@ -65,8 +65,8 @@ class ModelConfig:
class SliderTargetConfig:
def __init__(self, **kwargs):
self.target_class: str = kwargs.get('target_class', '')
self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None)
self.positive: str = kwargs.get('positive', '')
self.negative: str = kwargs.get('negative', '')
self.multiplier: float = kwargs.get('multiplier', 1.0)
self.weight: float = kwargs.get('weight', 1.0)