Added shuffeling to prompts

This commit is contained in:
Jaret Burkett
2023-08-19 07:57:30 -06:00
parent 90eedb78bf
commit c6675e2801
4 changed files with 59 additions and 6 deletions

View File

@@ -7,6 +7,10 @@ from tqdm import tqdm
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
import itertools
if TYPE_CHECKING:
from toolkit.config_modules import SliderTargetConfig
class ACTION_TYPES_SLIDER:
@@ -226,6 +230,40 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
return anchors
def get_permutations(s):
# Split the string by comma
phrases = [phrase.strip() for phrase in s.split(',')]
# remove empty strings
phrases = [phrase for phrase in phrases if len(phrase) > 0]
# Get all permutations
permutations = list(itertools.permutations(phrases))
# Convert the tuples back to comma separated strings
return [', '.join(permutation) for permutation in permutations]
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive)
neg_permutations = get_permutations(target.negative)
permutations = []
for pos, neg in itertools.product(pos_permutations, neg_permutations):
permutations.append(
SliderTargetConfig(
target_class=target.target_class,
positive=pos,
negative=neg,
multiplier=target.multiplier,
weight=target.weight
)
)
return permutations
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
@@ -291,10 +329,6 @@ def encode_prompts_to_cache(
return cache
if TYPE_CHECKING:
from toolkit.config_modules import SliderTargetConfig
@torch.no_grad()
def build_prompt_pair_batch_from_cache(
cache: PromptEmbedsCache,