mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added shuffeling to prompts
This commit is contained in:
@@ -7,6 +7,10 @@ from tqdm import tqdm
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
import itertools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
|
||||
|
||||
class ACTION_TYPES_SLIDER:
|
||||
@@ -226,6 +230,40 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
|
||||
return anchors
|
||||
|
||||
|
||||
def get_permutations(s):
|
||||
# Split the string by comma
|
||||
phrases = [phrase.strip() for phrase in s.split(',')]
|
||||
|
||||
# remove empty strings
|
||||
phrases = [phrase for phrase in phrases if len(phrase) > 0]
|
||||
|
||||
# Get all permutations
|
||||
permutations = list(itertools.permutations(phrases))
|
||||
|
||||
# Convert the tuples back to comma separated strings
|
||||
return [', '.join(permutation) for permutation in permutations]
|
||||
|
||||
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
pos_permutations = get_permutations(target.positive)
|
||||
neg_permutations = get_permutations(target.negative)
|
||||
|
||||
permutations = []
|
||||
for pos, neg in itertools.product(pos_permutations, neg_permutations):
|
||||
permutations.append(
|
||||
SliderTargetConfig(
|
||||
target_class=target.target_class,
|
||||
positive=pos,
|
||||
negative=neg,
|
||||
multiplier=target.multiplier,
|
||||
weight=target.weight
|
||||
)
|
||||
)
|
||||
|
||||
return permutations
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
@@ -291,10 +329,6 @@ def encode_prompts_to_cache(
|
||||
return cache
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def build_prompt_pair_batch_from_cache(
|
||||
cache: PromptEmbedsCache,
|
||||
|
||||
Reference in New Issue
Block a user