mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Ultimate slider training built, still needs tuning
This commit is contained in:
@@ -34,7 +34,8 @@ class EncodedPromptPair:
|
||||
action_list=None,
|
||||
multiplier=1.0,
|
||||
multiplier_list=None,
|
||||
weight=1.0
|
||||
weight=1.0,
|
||||
target: 'SliderTargetConfig' = None,
|
||||
):
|
||||
self.target_class: PromptEmbeds = target_class
|
||||
self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
|
||||
@@ -46,6 +47,7 @@ class EncodedPromptPair:
|
||||
self.empty_prompt: PromptEmbeds = empty_prompt
|
||||
self.both_targets: PromptEmbeds = both_targets
|
||||
self.multiplier: float = multiplier
|
||||
self.target: 'SliderTargetConfig' = target
|
||||
if multiplier_list is not None:
|
||||
self.multiplier_list: list[float] = multiplier_list
|
||||
else:
|
||||
@@ -109,7 +111,8 @@ def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
|
||||
both_targets=both_targets,
|
||||
action_list=action_list,
|
||||
multiplier_list=multiplier_list,
|
||||
weight=weight
|
||||
weight=weight,
|
||||
target=prompt_pairs[0].target
|
||||
)
|
||||
|
||||
|
||||
@@ -160,7 +163,8 @@ def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List
|
||||
both_targets=both_targets_splits[i],
|
||||
action_list=action_list_split,
|
||||
multiplier_list=multiplier_list_split,
|
||||
weight=concatenated.weight
|
||||
weight=concatenated.weight,
|
||||
target=concatenated.target
|
||||
)
|
||||
prompt_pairs.append(prompt_pair)
|
||||
|
||||
@@ -358,7 +362,8 @@ def build_prompt_pair_batch_from_cache(
|
||||
multiplier=target.multiplier,
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
empty_prompt=cache[""],
|
||||
weight=target.weight
|
||||
weight=target.weight,
|
||||
target=target
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
@@ -377,7 +382,8 @@ def build_prompt_pair_batch_from_cache(
|
||||
multiplier=target.multiplier,
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
empty_prompt=cache[""],
|
||||
weight=target.weight
|
||||
weight=target.weight,
|
||||
target=target
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
@@ -396,7 +402,8 @@ def build_prompt_pair_batch_from_cache(
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
empty_prompt=cache[""],
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
weight=target.weight,
|
||||
target=target
|
||||
),
|
||||
]
|
||||
if both or erase_negative:
|
||||
@@ -415,8 +422,39 @@ def build_prompt_pair_batch_from_cache(
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
empty_prompt=cache[""],
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
weight=target.weight,
|
||||
target=target
|
||||
),
|
||||
]
|
||||
|
||||
return prompt_pair_batch
|
||||
|
||||
|
||||
def build_latent_image_batch_for_prompt_pair(
|
||||
pos_latent,
|
||||
neg_latent,
|
||||
prompt_pair: EncodedPromptPair,
|
||||
prompt_chunk_size
|
||||
):
|
||||
erase_negative = len(prompt_pair.target.positive.strip()) == 0
|
||||
enhance_positive = len(prompt_pair.target.negative.strip()) == 0
|
||||
both = not erase_negative and not enhance_positive
|
||||
|
||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair, prompt_chunk_size)
|
||||
if both and len(prompt_pair_chunks) != 4:
|
||||
raise Exception("Invalid prompt pair chunks")
|
||||
if (erase_negative or enhance_positive) and len(prompt_pair_chunks) != 2:
|
||||
raise Exception("Invalid prompt pair chunks")
|
||||
|
||||
latent_list = []
|
||||
|
||||
if both or erase_negative:
|
||||
latent_list.append(pos_latent)
|
||||
if both or enhance_positive:
|
||||
latent_list.append(pos_latent)
|
||||
if both or enhance_positive:
|
||||
latent_list.append(neg_latent)
|
||||
if both or erase_negative:
|
||||
latent_list.append(neg_latent)
|
||||
|
||||
return torch.cat(latent_list, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user