From 70327172944e748e5a429dd08f2458e944ba8e6f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 24 Jul 2023 17:33:45 -0600 Subject: [PATCH] Added ability to just erase or enhance concepts from a model --- jobs/process/TrainSliderProcess.py | 116 ++++++++++++++++------------- toolkit/config_modules.py | 4 +- 2 files changed, 68 insertions(+), 52 deletions(-) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 70ddbcd6..2bed4b29 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -102,62 +102,78 @@ class TrainSliderProcess(BaseSDTrainProcess): cache[prompt] = train_util.encode_prompts( self.sd.tokenizer, self.sd.text_encoder, [prompt] ) + only_erase = len(target.positive.strip()) == 0 + only_enhance = len(target.negative.strip()) == 0 + both = not only_erase and not only_enhance + + if only_erase and only_enhance: + raise ValueError("target must have at least one of positive or negative or both") # for slider we need to have an enhancer, an eraser, and then # an inverse with negative weights to balance the network # if we don't do this, we will get different contrast and focus. # we only perform actions of enhancing and erasing on the negative # todo work on way to do all of this in one shot - prompt_pairs += [ - # erase standard - EncodedPromptPair( - target_class=cache[target.target_class], - positive=cache[target.positive], - negative=cache[target.negative], - neutral=cache[neutral], - width=width, - height=height, - action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - multiplier=target.multiplier, - weight=target.weight - ), - # erase inverted - EncodedPromptPair( - target_class=cache[target.target_class], - positive=cache[target.negative], - negative=cache[target.positive], - neutral=cache[neutral], - width=width, - height=height, - action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - multiplier=target.multiplier * -1.0, - weight=target.weight - ), - # enhance standard, swap pos neg - EncodedPromptPair( - target_class=cache[target.target_class], - positive=cache[target.negative], - negative=cache[target.positive], - neutral=cache[neutral], - width=width, - height=height, - action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, - multiplier=target.multiplier, - weight=target.weight - ), - # enhance inverted - EncodedPromptPair( - target_class=cache[target.target_class], - positive=cache[target.positive], - negative=cache[target.negative], - neutral=cache[neutral], - width=width, - height=height, - action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, - multiplier=target.multiplier * -1.0, - weight=target.weight - ), - ] + + if both or only_erase: + prompt_pairs += [ + # erase standard + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier, + weight=target.weight + ), + ] + if both or only_enhance: + prompt_pairs += [ + # enhance standard, swap pos neg + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.negative], + negative=cache[target.positive], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier, + weight=target.weight + ), + ] + if both: + prompt_pairs += [ + # erase inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.negative], + negative=cache[target.positive], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] + prompt_pairs += [ + # enhance inverted + EncodedPromptPair( + target_class=cache[target.target_class], + positive=cache[target.positive], + negative=cache[target.negative], + neutral=cache[neutral], + width=width, + height=height, + action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, + multiplier=target.multiplier * -1.0, + weight=target.weight + ), + ] # setup anchors anchor_pairs = [] diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 4a22f331..a94de093 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -65,8 +65,8 @@ class ModelConfig: class SliderTargetConfig: def __init__(self, **kwargs): self.target_class: str = kwargs.get('target_class', '') - self.positive: str = kwargs.get('positive', None) - self.negative: str = kwargs.get('negative', None) + self.positive: str = kwargs.get('positive', '') + self.negative: str = kwargs.get('negative', '') self.multiplier: float = kwargs.get('multiplier', 1.0) self.weight: float = kwargs.get('weight', 1.0)