mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 06:49:08 +00:00
Added a way to add a t2i adapter guided slider training for more consitant images
This commit is contained in:
@@ -186,18 +186,23 @@ class SliderConfig:
|
||||
self.prompt_file: str = kwargs.get('prompt_file', None)
|
||||
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
|
||||
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
|
||||
self.high_ram = kwargs.get('high_ram', False)
|
||||
|
||||
# expand targets if shuffling
|
||||
from toolkit.prompt_utils import get_slider_target_permutations
|
||||
self.targets: List[SliderTargetConfig] = []
|
||||
targets = [SliderTargetConfig(**target) for target in targets]
|
||||
# do permutations if shuffle is true
|
||||
print(f"Building slider targets")
|
||||
for target in targets:
|
||||
if target.shuffle:
|
||||
target_permutations = get_slider_target_permutations(target)
|
||||
target_permutations = get_slider_target_permutations(target, max_permutations=100)
|
||||
self.targets = self.targets + target_permutations
|
||||
else:
|
||||
self.targets.append(target)
|
||||
print(f"Built {len(self.targets)} slider targets (with permutations)")
|
||||
|
||||
|
||||
class DatasetConfig:
|
||||
|
||||
@@ -105,6 +105,18 @@ class EncodedPromptPair:
|
||||
self.both_targets = self.both_targets.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
self.target_class = self.target_class.detach()
|
||||
self.target_class_with_neutral = self.target_class_with_neutral.detach()
|
||||
self.positive_target = self.positive_target.detach()
|
||||
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
|
||||
self.negative_target = self.negative_target.detach()
|
||||
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
|
||||
self.neutral = self.neutral.detach()
|
||||
self.empty_prompt = self.empty_prompt.detach()
|
||||
self.both_targets = self.both_targets.detach()
|
||||
return self
|
||||
|
||||
|
||||
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
|
||||
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
|
||||
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
|
||||
return anchors
|
||||
|
||||
|
||||
def get_permutations(s):
|
||||
def get_permutations(s, max_permutations=8):
|
||||
# Split the string by comma
|
||||
phrases = [phrase.strip() for phrase in s.split(',')]
|
||||
|
||||
# remove empty strings
|
||||
phrases = [phrase for phrase in phrases if len(phrase) > 0]
|
||||
# shuffle the list
|
||||
random.shuffle(phrases)
|
||||
|
||||
# Get all permutations
|
||||
permutations = list(itertools.permutations(phrases))
|
||||
permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])
|
||||
|
||||
# Convert the tuples back to comma separated strings
|
||||
return [', '.join(permutation) for permutation in permutations]
|
||||
@@ -283,8 +297,8 @@ def get_permutations(s):
|
||||
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
pos_permutations = get_permutations(target.positive)
|
||||
neg_permutations = get_permutations(target.negative)
|
||||
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
|
||||
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
|
||||
|
||||
permutations = []
|
||||
for pos, neg in itertools.product(pos_permutations, neg_permutations):
|
||||
|
||||
Reference in New Issue
Block a user