mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Moved some of the job config into base process so it will be easier to extend extensions
This commit is contained in:
@@ -56,6 +56,7 @@ class EncodedPromptPair:
|
||||
# simulate torch to for tensors
|
||||
def to(self, *args, **kwargs):
|
||||
self.target_class = self.target_class.to(*args, **kwargs)
|
||||
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
|
||||
self.positive_target = self.positive_target.to(*args, **kwargs)
|
||||
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
|
||||
self.negative_target = self.negative_target.to(*args, **kwargs)
|
||||
@@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
prompt_pair_batch = []
|
||||
|
||||
if both or erase_negative:
|
||||
print("Encoding erase negative")
|
||||
# print("Encoding erase negative")
|
||||
prompt_pair_batch += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
@@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
print("Encoding enhance positive")
|
||||
# print("Encoding enhance positive")
|
||||
prompt_pair_batch += [
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
@@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
print("Encoding erase positive (inverse)")
|
||||
# print("Encoding erase positive (inverse)")
|
||||
prompt_pair_batch += [
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
@@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache(
|
||||
),
|
||||
]
|
||||
if both or erase_negative:
|
||||
print("Encoding enhance negative (inverse)")
|
||||
# print("Encoding enhance negative (inverse)")
|
||||
prompt_pair_batch += [
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
|
||||
Reference in New Issue
Block a user