Ultimate slider training built, still needs tuning

This commit is contained in:
Jaret Burkett
2023-08-19 18:54:34 -06:00
parent b77b9acc0b
commit bef5551ea5
2 changed files with 202 additions and 16 deletions

View File

@@ -34,7 +34,8 @@ class EncodedPromptPair:
action_list=None,
multiplier=1.0,
multiplier_list=None,
weight=1.0
weight=1.0,
target: 'SliderTargetConfig' = None,
):
self.target_class: PromptEmbeds = target_class
self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
@@ -46,6 +47,7 @@ class EncodedPromptPair:
self.empty_prompt: PromptEmbeds = empty_prompt
self.both_targets: PromptEmbeds = both_targets
self.multiplier: float = multiplier
self.target: 'SliderTargetConfig' = target
if multiplier_list is not None:
self.multiplier_list: list[float] = multiplier_list
else:
@@ -109,7 +111,8 @@ def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
both_targets=both_targets,
action_list=action_list,
multiplier_list=multiplier_list,
weight=weight
weight=weight,
target=prompt_pairs[0].target
)
@@ -160,7 +163,8 @@ def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List
both_targets=both_targets_splits[i],
action_list=action_list_split,
multiplier_list=multiplier_list_split,
weight=concatenated.weight
weight=concatenated.weight,
target=concatenated.target
)
prompt_pairs.append(prompt_pair)
@@ -358,7 +362,8 @@ def build_prompt_pair_batch_from_cache(
multiplier=target.multiplier,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
weight=target.weight
weight=target.weight,
target=target
),
]
if both or enhance_positive:
@@ -377,7 +382,8 @@ def build_prompt_pair_batch_from_cache(
multiplier=target.multiplier,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
weight=target.weight
weight=target.weight,
target=target
),
]
if both or enhance_positive:
@@ -396,7 +402,8 @@ def build_prompt_pair_batch_from_cache(
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
multiplier=target.multiplier * -1.0,
weight=target.weight
weight=target.weight,
target=target
),
]
if both or erase_negative:
@@ -415,8 +422,39 @@ def build_prompt_pair_batch_from_cache(
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
empty_prompt=cache[""],
multiplier=target.multiplier * -1.0,
weight=target.weight
weight=target.weight,
target=target
),
]
return prompt_pair_batch
def build_latent_image_batch_for_prompt_pair(
pos_latent,
neg_latent,
prompt_pair: EncodedPromptPair,
prompt_chunk_size
):
erase_negative = len(prompt_pair.target.positive.strip()) == 0
enhance_positive = len(prompt_pair.target.negative.strip()) == 0
both = not erase_negative and not enhance_positive
prompt_pair_chunks = split_prompt_pairs(prompt_pair, prompt_chunk_size)
if both and len(prompt_pair_chunks) != 4:
raise Exception("Invalid prompt pair chunks")
if (erase_negative or enhance_positive) and len(prompt_pair_chunks) != 2:
raise Exception("Invalid prompt pair chunks")
latent_list = []
if both or erase_negative:
latent_list.append(pos_latent)
if both or enhance_positive:
latent_list.append(pos_latent)
if both or enhance_positive:
latent_list.append(neg_latent)
if both or erase_negative:
latent_list.append(neg_latent)
return torch.cat(latent_list, dim=0)