mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added a way to add a t2i adapter guided slider training for more consitant images
This commit is contained in:
@@ -1,8 +1,17 @@
|
|||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from diffusers import T2IAdapter
|
||||||
|
from torchvision.transforms import transforms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from toolkit.basic import value_map
|
||||||
from toolkit.config_modules import SliderConfig
|
from toolkit.config_modules import SliderConfig
|
||||||
|
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||||
|
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||||
import gc
|
import gc
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
@@ -21,6 +30,10 @@ def flush():
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
adapter_transforms = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
|
||||||
class TrainSliderProcess(BaseSDTrainProcess):
|
class TrainSliderProcess(BaseSDTrainProcess):
|
||||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||||
super().__init__(process_id, job, config)
|
super().__init__(process_id, job, config)
|
||||||
@@ -42,6 +55,27 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# trim targets
|
# trim targets
|
||||||
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
|
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
|
||||||
|
|
||||||
|
# get presets
|
||||||
|
self.eval_slider_device_state = get_train_sd_device_state_preset(
|
||||||
|
self.device_torch,
|
||||||
|
train_unet=False,
|
||||||
|
train_text_encoder=False,
|
||||||
|
cached_latents=self.is_latents_cached,
|
||||||
|
train_lora=False,
|
||||||
|
train_adapter=False,
|
||||||
|
train_embedding=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.train_slider_device_state = get_train_sd_device_state_preset(
|
||||||
|
self.device_torch,
|
||||||
|
train_unet=self.train_config.train_unet,
|
||||||
|
train_text_encoder=False,
|
||||||
|
cached_latents=self.is_latents_cached,
|
||||||
|
train_lora=True,
|
||||||
|
train_adapter=False,
|
||||||
|
train_embedding=False,
|
||||||
|
)
|
||||||
|
|
||||||
def before_model_load(self):
|
def before_model_load(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -66,6 +100,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# trim list to our max steps
|
# trim list to our max steps
|
||||||
|
|
||||||
cache = PromptEmbedsCache()
|
cache = PromptEmbedsCache()
|
||||||
|
print(f"Building prompt cache")
|
||||||
|
|
||||||
# get encoded latents for our prompts
|
# get encoded latents for our prompts
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -175,30 +210,95 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
self.sd.vae.to(self.device_torch)
|
self.sd.vae.to(self.device_torch)
|
||||||
# end hook_before_train_loop
|
# end hook_before_train_loop
|
||||||
|
|
||||||
|
def before_dataset_load(self):
|
||||||
|
if self.slider_config.use_adapter == 'depth':
|
||||||
|
print(f"Loading T2I Adapter for depth")
|
||||||
|
# called before LoRA network is loaded but after model is loaded
|
||||||
|
# attach the adapter here so it is there before we load the network
|
||||||
|
adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
|
||||||
|
if self.sd.is_xl:
|
||||||
|
adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
|
||||||
|
|
||||||
|
# dont name this adapter since we are not training it
|
||||||
|
self.t2i_adapter = T2IAdapter.from_pretrained(
|
||||||
|
adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
|
||||||
|
).to(self.device_torch)
|
||||||
|
self.t2i_adapter.eval()
|
||||||
|
self.t2i_adapter.requires_grad_(False)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
|
||||||
|
|
||||||
|
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||||
|
adapter_folder_path = self.slider_config.adapter_img_dir
|
||||||
|
adapter_images = []
|
||||||
|
# loop through images
|
||||||
|
for file_item in batch.file_items:
|
||||||
|
img_path = file_item.path
|
||||||
|
file_name_no_ext = os.path.basename(img_path).split('.')[0]
|
||||||
|
# find the image
|
||||||
|
for ext in img_ext_list:
|
||||||
|
if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
|
||||||
|
adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
|
||||||
|
break
|
||||||
|
width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
|
||||||
|
adapter_tensors = []
|
||||||
|
# load images with torch transforms
|
||||||
|
for idx, adapter_image in enumerate(adapter_images):
|
||||||
|
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
|
||||||
|
# to the smallest dimension
|
||||||
|
img: Image.Image = Image.open(adapter_image)
|
||||||
|
if img.width > img.height:
|
||||||
|
# scale down so height is the same as batch
|
||||||
|
new_height = height
|
||||||
|
new_width = int(img.width * (height / img.height))
|
||||||
|
else:
|
||||||
|
new_width = width
|
||||||
|
new_height = int(img.height * (width / img.width))
|
||||||
|
|
||||||
|
img = img.resize((new_width, new_height))
|
||||||
|
crop_fn = transforms.CenterCrop((height, width))
|
||||||
|
# crop the center to match batch
|
||||||
|
img = crop_fn(img)
|
||||||
|
img = adapter_transforms(img)
|
||||||
|
adapter_tensors.append(img)
|
||||||
|
|
||||||
|
# stack them
|
||||||
|
adapter_tensors = torch.stack(adapter_tensors).to(
|
||||||
|
self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
|
||||||
|
)
|
||||||
|
return adapter_tensors
|
||||||
|
|
||||||
def hook_train_loop(self, batch):
|
def hook_train_loop(self, batch):
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
# set to eval mode
|
||||||
|
self.sd.set_device_state(self.eval_slider_device_state)
|
||||||
|
with torch.no_grad():
|
||||||
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
|
# get a random pair
|
||||||
|
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||||
|
torch.randint(0, len(self.prompt_pairs), (1,)).item()
|
||||||
|
]
|
||||||
|
# move to device and dtype
|
||||||
|
prompt_pair.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
# get a random pair
|
# get a random resolution
|
||||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
height, width = self.slider_config.resolutions[
|
||||||
torch.randint(0, len(self.prompt_pairs), (1,)).item()
|
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
|
||||||
]
|
]
|
||||||
# move to device and dtype
|
if self.train_config.gradient_checkpointing:
|
||||||
prompt_pair.to(self.device_torch, dtype=dtype)
|
# may get disabled elsewhere
|
||||||
|
self.sd.unet.enable_gradient_checkpointing()
|
||||||
# get a random resolution
|
|
||||||
height, width = self.slider_config.resolutions[
|
|
||||||
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
|
|
||||||
]
|
|
||||||
if self.train_config.gradient_checkpointing:
|
|
||||||
# may get disabled elsewhere
|
|
||||||
self.sd.unet.enable_gradient_checkpointing()
|
|
||||||
|
|
||||||
noise_scheduler = self.sd.noise_scheduler
|
noise_scheduler = self.sd.noise_scheduler
|
||||||
optimizer = self.optimizer
|
optimizer = self.optimizer
|
||||||
lr_scheduler = self.lr_scheduler
|
lr_scheduler = self.lr_scheduler
|
||||||
|
|
||||||
loss_function = torch.nn.MSELoss()
|
loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
pred_kwargs = {}
|
||||||
|
|
||||||
def get_noise_pred(neg, pos, gs, cts, dn):
|
def get_noise_pred(neg, pos, gs, cts, dn):
|
||||||
return self.sd.predict_noise(
|
return self.sd.predict_noise(
|
||||||
latents=dn,
|
latents=dn,
|
||||||
@@ -209,9 +309,11 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
),
|
),
|
||||||
timestep=cts,
|
timestep=cts,
|
||||||
guidance_scale=gs,
|
guidance_scale=gs,
|
||||||
|
**pred_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
adapter_images = None
|
||||||
# for a complete slider, the batch size is 4 to begin with now
|
# for a complete slider, the batch size is 4 to begin with now
|
||||||
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
|
true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
|
||||||
from_batch = False
|
from_batch = False
|
||||||
@@ -219,9 +321,32 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# traing from a batch of images, not generating ourselves
|
# traing from a batch of images, not generating ourselves
|
||||||
from_batch = True
|
from_batch = True
|
||||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||||
|
if self.slider_config.adapter_img_dir is not None:
|
||||||
|
adapter_images = self.get_adapter_images(batch)
|
||||||
|
adapter_strength_min = 0.9
|
||||||
|
adapter_strength_max = 1.0
|
||||||
|
|
||||||
denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size
|
def rand_strength(sample):
|
||||||
denoised_latents = torch.cat(denoised_latent_chunks, dim=0)
|
adapter_conditioning_scale = torch.rand(
|
||||||
|
(1,), device=self.device_torch, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter_conditioning_scale = value_map(
|
||||||
|
adapter_conditioning_scale,
|
||||||
|
0.0,
|
||||||
|
1.0,
|
||||||
|
adapter_strength_min,
|
||||||
|
adapter_strength_max
|
||||||
|
)
|
||||||
|
return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
|
||||||
|
|
||||||
|
down_block_additional_residuals = self.t2i_adapter(adapter_images)
|
||||||
|
down_block_additional_residuals = [
|
||||||
|
rand_strength(sample) for sample in down_block_additional_residuals
|
||||||
|
]
|
||||||
|
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||||
|
|
||||||
|
denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
|
||||||
current_timestep = timesteps
|
current_timestep = timesteps
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@@ -229,8 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
self.train_config.max_denoising_steps, device=self.device_torch
|
self.train_config.max_denoising_steps, device=self.device_torch
|
||||||
)
|
)
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
# ger a random number of steps
|
# ger a random number of steps
|
||||||
timesteps_to = torch.randint(
|
timesteps_to = torch.randint(
|
||||||
1, self.train_config.max_denoising_steps, (1,)
|
1, self.train_config.max_denoising_steps, (1,)
|
||||||
@@ -267,13 +390,14 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
|
|
||||||
noise_scheduler.set_timesteps(1000)
|
noise_scheduler.set_timesteps(1000)
|
||||||
|
|
||||||
# split the latents into out prompt pair chunks
|
|
||||||
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
|
|
||||||
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
|
|
||||||
|
|
||||||
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||||
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||||
|
|
||||||
|
# split the latents into out prompt pair chunks
|
||||||
|
denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
|
||||||
|
denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
|
||||||
|
|
||||||
# flush() # 4.2GB to 3GB on 512x512
|
# flush() # 4.2GB to 3GB on 512x512
|
||||||
|
|
||||||
# 4.20 GB RAM for 512x512
|
# 4.20 GB RAM for 512x512
|
||||||
@@ -286,7 +410,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
)
|
)
|
||||||
positive_latents = positive_latents.detach()
|
positive_latents = positive_latents.detach()
|
||||||
positive_latents.requires_grad = False
|
positive_latents.requires_grad = False
|
||||||
positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0)
|
|
||||||
|
|
||||||
neutral_latents = get_noise_pred(
|
neutral_latents = get_noise_pred(
|
||||||
prompt_pair.positive_target, # negative prompt
|
prompt_pair.positive_target, # negative prompt
|
||||||
@@ -297,7 +420,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
)
|
)
|
||||||
neutral_latents = neutral_latents.detach()
|
neutral_latents = neutral_latents.detach()
|
||||||
neutral_latents.requires_grad = False
|
neutral_latents.requires_grad = False
|
||||||
neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0)
|
|
||||||
|
|
||||||
unconditional_latents = get_noise_pred(
|
unconditional_latents = get_noise_pred(
|
||||||
prompt_pair.positive_target, # negative prompt
|
prompt_pair.positive_target, # negative prompt
|
||||||
@@ -308,13 +430,13 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
)
|
)
|
||||||
unconditional_latents = unconditional_latents.detach()
|
unconditional_latents = unconditional_latents.detach()
|
||||||
unconditional_latents.requires_grad = False
|
unconditional_latents.requires_grad = False
|
||||||
unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0)
|
|
||||||
|
|
||||||
denoised_latents = denoised_latents.detach()
|
denoised_latents = denoised_latents.detach()
|
||||||
|
|
||||||
# flush() # 4.2GB to 3GB on 512x512
|
self.sd.set_device_state(self.train_slider_device_state)
|
||||||
|
# start accumulating gradients
|
||||||
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# 4.20 GB RAM for 512x512
|
|
||||||
anchor_loss_float = None
|
anchor_loss_float = None
|
||||||
if len(self.anchor_pairs) > 0:
|
if len(self.anchor_pairs) > 0:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -369,9 +491,23 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
del anchor_target_noise
|
del anchor_target_noise
|
||||||
# move anchor back to cpu
|
# move anchor back to cpu
|
||||||
anchor.to("cpu")
|
anchor.to("cpu")
|
||||||
# flush()
|
|
||||||
|
|
||||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
|
with torch.no_grad():
|
||||||
|
if self.slider_config.high_ram:
|
||||||
|
# run through in one instance
|
||||||
|
prompt_pair_chunks = [prompt_pair.detach()]
|
||||||
|
denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
|
||||||
|
positive_latents_chunks = [positive_latents.detach()]
|
||||||
|
neutral_latents_chunks = [neutral_latents.detach()]
|
||||||
|
unconditional_latents_chunks = [unconditional_latents.detach()]
|
||||||
|
else:
|
||||||
|
prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
|
||||||
|
denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
|
||||||
|
positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||||
|
neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||||
|
unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
|
||||||
|
|
||||||
|
# flush()
|
||||||
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
|
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
|
||||||
# 3.28 GB RAM for 512x512
|
# 3.28 GB RAM for 512x512
|
||||||
with self.network:
|
with self.network:
|
||||||
|
|||||||
@@ -186,18 +186,23 @@ class SliderConfig:
|
|||||||
self.prompt_file: str = kwargs.get('prompt_file', None)
|
self.prompt_file: str = kwargs.get('prompt_file', None)
|
||||||
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
||||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||||
|
self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
|
||||||
|
self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
|
||||||
|
self.high_ram = kwargs.get('high_ram', False)
|
||||||
|
|
||||||
# expand targets if shuffling
|
# expand targets if shuffling
|
||||||
from toolkit.prompt_utils import get_slider_target_permutations
|
from toolkit.prompt_utils import get_slider_target_permutations
|
||||||
self.targets: List[SliderTargetConfig] = []
|
self.targets: List[SliderTargetConfig] = []
|
||||||
targets = [SliderTargetConfig(**target) for target in targets]
|
targets = [SliderTargetConfig(**target) for target in targets]
|
||||||
# do permutations if shuffle is true
|
# do permutations if shuffle is true
|
||||||
|
print(f"Building slider targets")
|
||||||
for target in targets:
|
for target in targets:
|
||||||
if target.shuffle:
|
if target.shuffle:
|
||||||
target_permutations = get_slider_target_permutations(target)
|
target_permutations = get_slider_target_permutations(target, max_permutations=100)
|
||||||
self.targets = self.targets + target_permutations
|
self.targets = self.targets + target_permutations
|
||||||
else:
|
else:
|
||||||
self.targets.append(target)
|
self.targets.append(target)
|
||||||
|
print(f"Built {len(self.targets)} slider targets (with permutations)")
|
||||||
|
|
||||||
|
|
||||||
class DatasetConfig:
|
class DatasetConfig:
|
||||||
|
|||||||
@@ -105,6 +105,18 @@ class EncodedPromptPair:
|
|||||||
self.both_targets = self.both_targets.to(*args, **kwargs)
|
self.both_targets = self.both_targets.to(*args, **kwargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
self.target_class = self.target_class.detach()
|
||||||
|
self.target_class_with_neutral = self.target_class_with_neutral.detach()
|
||||||
|
self.positive_target = self.positive_target.detach()
|
||||||
|
self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
|
||||||
|
self.negative_target = self.negative_target.detach()
|
||||||
|
self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
|
||||||
|
self.neutral = self.neutral.detach()
|
||||||
|
self.empty_prompt = self.empty_prompt.detach()
|
||||||
|
self.both_targets = self.both_targets.detach()
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
|
def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
|
||||||
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
|
text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
|
||||||
@@ -267,15 +279,17 @@ def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[Enc
|
|||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
|
|
||||||
def get_permutations(s):
|
def get_permutations(s, max_permutations=8):
|
||||||
# Split the string by comma
|
# Split the string by comma
|
||||||
phrases = [phrase.strip() for phrase in s.split(',')]
|
phrases = [phrase.strip() for phrase in s.split(',')]
|
||||||
|
|
||||||
# remove empty strings
|
# remove empty strings
|
||||||
phrases = [phrase for phrase in phrases if len(phrase) > 0]
|
phrases = [phrase for phrase in phrases if len(phrase) > 0]
|
||||||
|
# shuffle the list
|
||||||
|
random.shuffle(phrases)
|
||||||
|
|
||||||
# Get all permutations
|
# Get all permutations
|
||||||
permutations = list(itertools.permutations(phrases))
|
permutations = list([p for p in itertools.islice(itertools.permutations(phrases), max_permutations)])
|
||||||
|
|
||||||
# Convert the tuples back to comma separated strings
|
# Convert the tuples back to comma separated strings
|
||||||
return [', '.join(permutation) for permutation in permutations]
|
return [', '.join(permutation) for permutation in permutations]
|
||||||
@@ -283,8 +297,8 @@ def get_permutations(s):
|
|||||||
|
|
||||||
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
|
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
|
||||||
from toolkit.config_modules import SliderTargetConfig
|
from toolkit.config_modules import SliderTargetConfig
|
||||||
pos_permutations = get_permutations(target.positive)
|
pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
|
||||||
neg_permutations = get_permutations(target.negative)
|
neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
|
||||||
|
|
||||||
permutations = []
|
permutations = []
|
||||||
for pos, neg in itertools.product(pos_permutations, neg_permutations):
|
for pos, neg in itertools.product(pos_permutations, neg_permutations):
|
||||||
|
|||||||
Reference in New Issue
Block a user