mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added shuffeling to prompts
This commit is contained in:
@@ -184,6 +184,10 @@ config:
|
|||||||
# if you are doing more than one target it may be good to set less important ones
|
# if you are doing more than one target it may be good to set less important ones
|
||||||
# to a lower number like 0.1 so they don't outweigh the primary target
|
# to a lower number like 0.1 so they don't outweigh the primary target
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
|
# shuffle the prompts split by the comma. We will run every combination randomly
|
||||||
|
# this will make the LoRA more robust. You probably want this on unless prompt order
|
||||||
|
# is important for some reason
|
||||||
|
shuffle: true
|
||||||
|
|
||||||
|
|
||||||
# anchors are prompts that we will try to hold on to while training the slider
|
# anchors are prompts that we will try to hold on to while training the slider
|
||||||
|
|||||||
@@ -86,6 +86,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# remove duplicates
|
# remove duplicates
|
||||||
prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
|
prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
|
||||||
|
|
||||||
|
# trim to max steps if max steps is lower than prompt count
|
||||||
|
prompts_to_cache = prompts_to_cache[:self.train_config.steps]
|
||||||
|
|
||||||
# encode them
|
# encode them
|
||||||
cache = encode_prompts_to_cache(
|
cache = encode_prompts_to_cache(
|
||||||
prompt_list=prompts_to_cache,
|
prompt_list=prompts_to_cache,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import List, Optional
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SaveConfig:
|
class SaveConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.save_every: int = kwargs.get('save_every', 1000)
|
self.save_every: int = kwargs.get('save_every', 1000)
|
||||||
@@ -93,6 +94,7 @@ class SliderTargetConfig:
|
|||||||
self.negative: str = kwargs.get('negative', '')
|
self.negative: str = kwargs.get('negative', '')
|
||||||
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||||
self.weight: float = kwargs.get('weight', 1.0)
|
self.weight: float = kwargs.get('weight', 1.0)
|
||||||
|
self.shuffle: bool = kwargs.get('shuffle', False)
|
||||||
|
|
||||||
|
|
||||||
class SliderConfigAnchors:
|
class SliderConfigAnchors:
|
||||||
@@ -105,8 +107,6 @@ class SliderConfigAnchors:
|
|||||||
class SliderConfig:
|
class SliderConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
targets = kwargs.get('targets', [])
|
targets = kwargs.get('targets', [])
|
||||||
targets = [SliderTargetConfig(**target) for target in targets]
|
|
||||||
self.targets: List[SliderTargetConfig] = targets
|
|
||||||
anchors = kwargs.get('anchors', [])
|
anchors = kwargs.get('anchors', [])
|
||||||
anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
|
anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
|
||||||
self.anchors: List[SliderConfigAnchors] = anchors
|
self.anchors: List[SliderConfigAnchors] = anchors
|
||||||
@@ -115,6 +115,18 @@ class SliderConfig:
|
|||||||
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
||||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||||
|
|
||||||
|
# expand targets if shuffling
|
||||||
|
from toolkit.prompt_utils import get_slider_target_permutations
|
||||||
|
self.targets: List[SliderTargetConfig] = []
|
||||||
|
targets = [SliderTargetConfig(**target) for target in targets]
|
||||||
|
# do permutations if shuffle is true
|
||||||
|
for target in targets:
|
||||||
|
if target.shuffle:
|
||||||
|
target_permutations = get_slider_target_permutations(target)
|
||||||
|
self.targets = self.targets + target_permutations
|
||||||
|
else:
|
||||||
|
self.targets.append(target)
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageConfig:
|
class GenerateImageConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -7,6 +7,10 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.config_modules import SliderTargetConfig
|
||||||
|
|
||||||
|
|
||||||
class ACTION_TYPES_SLIDER:
|
class ACTION_TYPES_SLIDER:
|
||||||
@@ -226,6 +230,40 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
|
|||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
|
|
||||||
|
def get_permutations(s):
|
||||||
|
# Split the string by comma
|
||||||
|
phrases = [phrase.strip() for phrase in s.split(',')]
|
||||||
|
|
||||||
|
# remove empty strings
|
||||||
|
phrases = [phrase for phrase in phrases if len(phrase) > 0]
|
||||||
|
|
||||||
|
# Get all permutations
|
||||||
|
permutations = list(itertools.permutations(phrases))
|
||||||
|
|
||||||
|
# Convert the tuples back to comma separated strings
|
||||||
|
return [', '.join(permutation) for permutation in permutations]
|
||||||
|
|
||||||
|
|
||||||
|
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
|
||||||
|
from toolkit.config_modules import SliderTargetConfig
|
||||||
|
pos_permutations = get_permutations(target.positive)
|
||||||
|
neg_permutations = get_permutations(target.negative)
|
||||||
|
|
||||||
|
permutations = []
|
||||||
|
for pos, neg in itertools.product(pos_permutations, neg_permutations):
|
||||||
|
permutations.append(
|
||||||
|
SliderTargetConfig(
|
||||||
|
target_class=target.target_class,
|
||||||
|
positive=pos,
|
||||||
|
negative=neg,
|
||||||
|
multiplier=target.multiplier,
|
||||||
|
weight=target.weight
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return permutations
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
@@ -291,10 +329,6 @@ def encode_prompts_to_cache(
|
|||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from toolkit.config_modules import SliderTargetConfig
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def build_prompt_pair_batch_from_cache(
|
def build_prompt_pair_batch_from_cache(
|
||||||
cache: PromptEmbedsCache,
|
cache: PromptEmbedsCache,
|
||||||
|
|||||||
Reference in New Issue
Block a user