Moved some of the job config into base process so it will be easier to extend extensions

This commit is contained in:
Jaret Burkett
2023-08-10 12:14:05 -06:00
parent fbc8a87a05
commit df48f0a843
6 changed files with 43 additions and 30 deletions

View File

@@ -56,6 +56,7 @@ class EncodedPromptPair:
# simulate torch to for tensors
def to(self, *args, **kwargs):
self.target_class = self.target_class.to(*args, **kwargs)
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
self.positive_target = self.positive_target.to(*args, **kwargs)
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
self.negative_target = self.negative_target.to(*args, **kwargs)
@@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache(
prompt_pair_batch = []
if both or erase_negative:
print("Encoding erase negative")
# print("Encoding erase negative")
prompt_pair_batch += [
# erase standard
EncodedPromptPair(
@@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache(
),
]
if both or enhance_positive:
print("Encoding enhance positive")
# print("Encoding enhance positive")
prompt_pair_batch += [
# enhance standard, swap pos neg
EncodedPromptPair(
@@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache(
),
]
if both or enhance_positive:
print("Encoding erase positive (inverse)")
# print("Encoding erase positive (inverse)")
prompt_pair_batch += [
# erase inverted
EncodedPromptPair(
@@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache(
),
]
if both or erase_negative:
print("Encoding enhance negative (inverse)")
# print("Encoding enhance negative (inverse)")
prompt_pair_batch += [
# enhance inverted
EncodedPromptPair(