Added ability to just erase or enhance concepts from a model

This commit is contained in:
Jaret Burkett
2023-07-24 17:33:45 -06:00
parent 61dd818608
commit 7032717294
2 changed files with 68 additions and 52 deletions

View File

@@ -102,62 +102,78 @@ class TrainSliderProcess(BaseSDTrainProcess):
cache[prompt] = train_util.encode_prompts( cache[prompt] = train_util.encode_prompts(
self.sd.tokenizer, self.sd.text_encoder, [prompt] self.sd.tokenizer, self.sd.text_encoder, [prompt]
) )
only_erase = len(target.positive.strip()) == 0
only_enhance = len(target.negative.strip()) == 0
both = not only_erase and not only_enhance
if only_erase and only_enhance:
raise ValueError("target must have at least one of positive or negative or both")
# for slider we need to have an enhancer, an eraser, and then # for slider we need to have an enhancer, an eraser, and then
# an inverse with negative weights to balance the network # an inverse with negative weights to balance the network
# if we don't do this, we will get different contrast and focus. # if we don't do this, we will get different contrast and focus.
# we only perform actions of enhancing and erasing on the negative # we only perform actions of enhancing and erasing on the negative
# todo work on way to do all of this in one shot # todo work on way to do all of this in one shot
prompt_pairs += [
# erase standard if both or only_erase:
EncodedPromptPair( prompt_pairs += [
target_class=cache[target.target_class], # erase standard
positive=cache[target.positive], EncodedPromptPair(
negative=cache[target.negative], target_class=cache[target.target_class],
neutral=cache[neutral], positive=cache[target.positive],
width=width, negative=cache[target.negative],
height=height, neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, width=width,
multiplier=target.multiplier, height=height,
weight=target.weight action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
), multiplier=target.multiplier,
# erase inverted weight=target.weight
EncodedPromptPair( ),
target_class=cache[target.target_class], ]
positive=cache[target.negative], if both or only_enhance:
negative=cache[target.positive], prompt_pairs += [
neutral=cache[neutral], # enhance standard, swap pos neg
width=width, EncodedPromptPair(
height=height, target_class=cache[target.target_class],
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, positive=cache[target.negative],
multiplier=target.multiplier * -1.0, negative=cache[target.positive],
weight=target.weight neutral=cache[neutral],
), width=width,
# enhance standard, swap pos neg height=height,
EncodedPromptPair( action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
target_class=cache[target.target_class], multiplier=target.multiplier,
positive=cache[target.negative], weight=target.weight
negative=cache[target.positive], ),
neutral=cache[neutral], ]
width=width, if both:
height=height, prompt_pairs += [
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, # erase inverted
multiplier=target.multiplier, EncodedPromptPair(
weight=target.weight target_class=cache[target.target_class],
), positive=cache[target.negative],
# enhance inverted negative=cache[target.positive],
EncodedPromptPair( neutral=cache[neutral],
target_class=cache[target.target_class], width=width,
positive=cache[target.positive], height=height,
negative=cache[target.negative], action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
neutral=cache[neutral], multiplier=target.multiplier * -1.0,
width=width, weight=target.weight
height=height, ),
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, ]
multiplier=target.multiplier * -1.0, prompt_pairs += [
weight=target.weight # enhance inverted
), EncodedPromptPair(
] target_class=cache[target.target_class],
positive=cache[target.positive],
negative=cache[target.negative],
neutral=cache[neutral],
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]
# setup anchors # setup anchors
anchor_pairs = [] anchor_pairs = []

View File

@@ -65,8 +65,8 @@ class ModelConfig:
class SliderTargetConfig: class SliderTargetConfig:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.target_class: str = kwargs.get('target_class', '') self.target_class: str = kwargs.get('target_class', '')
self.positive: str = kwargs.get('positive', None) self.positive: str = kwargs.get('positive', '')
self.negative: str = kwargs.get('negative', None) self.negative: str = kwargs.get('negative', '')
self.multiplier: float = kwargs.get('multiplier', 1.0) self.multiplier: float = kwargs.get('multiplier', 1.0)
self.weight: float = kwargs.get('weight', 1.0) self.weight: float = kwargs.get('weight', 1.0)