From c6675e280174927592fe4dbaab9113371faba2bf Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 19 Aug 2023 07:57:30 -0600 Subject: [PATCH] Added shuffling to prompts --- config/examples/train_slider.example.yml | 4 +++ jobs/process/TrainSliderProcess.py | 3 ++ toolkit/config_modules.py | 16 +++++++-- toolkit/prompt_utils.py | 42 +++++++++++++++++++++--- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml index 4c00dc9c..b3600917 100644 --- a/config/examples/train_slider.example.yml +++ b/config/examples/train_slider.example.yml @@ -184,6 +184,10 @@ config: # if you are doing more than one target it may be good to set less important ones # to a lower number like 0.1 so they don't outweigh the primary target weight: 1.0 + # shuffle the prompts split by the comma. We will run every combination randomly + # this will make the LoRA more robust. You probably want this on unless prompt order + # is important for some reason + shuffle: true # anchors are prompts that we will try to hold on to while training the slider diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 8aa5bd11..97a072ae 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -86,6 +86,9 @@ class TrainSliderProcess(BaseSDTrainProcess): # remove duplicates prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) + # trim to max steps if max steps is lower than prompt count + prompts_to_cache = prompts_to_cache[:self.train_config.steps] + # encode them cache = encode_prompts_to_cache( prompt_list=prompts_to_cache, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 6b7b8a20..e41a1030 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -4,6 +4,7 @@ from typing import List, Optional import random + class SaveConfig: def __init__(self, **kwargs): self.save_every: int = kwargs.get('save_every', 
1000) @@ -93,6 +94,7 @@ class SliderTargetConfig: self.negative: str = kwargs.get('negative', '') self.multiplier: float = kwargs.get('multiplier', 1.0) self.weight: float = kwargs.get('weight', 1.0) + self.shuffle: bool = kwargs.get('shuffle', False) class SliderConfigAnchors: @@ -105,8 +107,6 @@ class SliderConfigAnchors: class SliderConfig: def __init__(self, **kwargs): targets = kwargs.get('targets', []) - targets = [SliderTargetConfig(**target) for target in targets] - self.targets: List[SliderTargetConfig] = targets anchors = kwargs.get('anchors', []) anchors = [SliderConfigAnchors(**anchor) for anchor in anchors] self.anchors: List[SliderConfigAnchors] = anchors @@ -115,6 +115,18 @@ class SliderConfig: self.prompt_tensors: str = kwargs.get('prompt_tensors', None) self.batch_full_slide: bool = kwargs.get('batch_full_slide', True) + # expand targets if shuffling + from toolkit.prompt_utils import get_slider_target_permutations + self.targets: List[SliderTargetConfig] = [] + targets = [SliderTargetConfig(**target) for target in targets] + # do permutations if shuffle is true + for target in targets: + if target.shuffle: + target_permutations = get_slider_target_permutations(target) + self.targets = self.targets + target_permutations + else: + self.targets.append(target) + class GenerateImageConfig: def __init__( diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index ea999274..aa5f42e5 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -7,6 +7,10 @@ from tqdm import tqdm from toolkit.stable_diffusion_model import PromptEmbeds from toolkit.train_tools import get_torch_dtype +import itertools + +if TYPE_CHECKING: + from toolkit.config_modules import SliderTargetConfig class ACTION_TYPES_SLIDER: @@ -226,6 +230,40 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc return anchors +def get_permutations(s): + # Split the string by comma + phrases = [phrase.strip() for phrase in s.split(',')] + + # remove 
empty strings + phrases = [phrase for phrase in phrases if len(phrase) > 0] + + # Get all permutations + permutations = list(itertools.permutations(phrases)) + + # Convert the tuples back to comma separated strings + return [', '.join(permutation) for permutation in permutations] + + +def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']: + from toolkit.config_modules import SliderTargetConfig + pos_permutations = get_permutations(target.positive) + neg_permutations = get_permutations(target.negative) + + permutations = [] + for pos, neg in itertools.product(pos_permutations, neg_permutations): + permutations.append( + SliderTargetConfig( + target_class=target.target_class, + positive=pos, + negative=neg, + multiplier=target.multiplier, + weight=target.weight + ) + ) + + return permutations + + if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion @@ -291,10 +329,6 @@ def encode_prompts_to_cache( return cache -if TYPE_CHECKING: - from toolkit.config_modules import SliderTargetConfig - - @torch.no_grad() def build_prompt_pair_batch_from_cache( cache: PromptEmbedsCache,