Added a way to add a t2i adapter guided slider training for more consitant images

This commit is contained in:
Jaret Burkett
2023-09-28 14:08:56 -06:00
parent c5d49ba661
commit 8509da60cb
3 changed files with 189 additions and 34 deletions

View File

@@ -105,6 +105,18 @@ class EncodedPromptPair:
self.both_targets = self.both_targets.to(*args, **kwargs)
return self
def detach(self):
self.target_class = self.target_class.detach()
self.target_class_with_neutral = self.target_class_with_neutral.detach()
self.positive_target = self.positive_target.detach()
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
self.negative_target = self.negative_target.detach()
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
self.neutral = self.neutral.detach()
self.empty_prompt = self.empty_prompt.detach()
self.both_targets = self.both_targets.detach()
return self
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
return anchors
def get_permutations(s):
def get_permutations(s, max_permutations=8):
# Split the string by comma
phrases = [phrase.strip() for phrase in s.split(',')]
# remove empty strings
phrases = [phrase for phrase in phrases if len(phrase) > 0]
# shuffle the list
random.shuffle(phrases)
# Get all permutations
permutations = list(itertools.permutations(phrases))
permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])
# Convert the tuples back to comma separated strings
return [', '.join(permutation) for permutation in permutations]
@@ -283,8 +297,8 @@ def get_permutations(s):
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive)
neg_permutations = get_permutations(target.negative)
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
permutations = []
for pos, neg in itertools.product(pos_permutations, neg_permutations):