Added Model rescale and prepared a release upgrade

This commit is contained in:
Jaret Burkett
2023-08-01 13:49:54 -06:00
parent 63cacf4362
commit 8b8d53888d
15 changed files with 388 additions and 64 deletions

View File

@@ -46,8 +46,8 @@ class EncodedPromptPair:
negative_target,
negative_target_with_neutral,
neutral,
both_targets,
empty_prompt,
both_targets,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0,
weight=1.0
@@ -123,23 +123,24 @@ class TrainSliderProcess(BaseSDTrainProcess):
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
# read line by line from file
with open(self.slider_config.prompt_file, 'r') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
if self.slider_config.prompt_file:
with open(self.slider_config.prompt_file, 'r') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
# trim list to our max steps
cache = PromptEmbedsCache()
if not self.slider_config.prompt_tensors:
# shuffle
random.shuffle(self.prompt_txt_list)
# trim to max steps
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
# trim list to our max steps
# get encoded latents for our prompts
with torch.no_grad():
if self.slider_config.prompt_tensors is not None:
@@ -169,7 +170,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
# encode empty_prompt
cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
for target in self.slider_config.targets:
prompt_list = [
f"{target.target_class}", # target_class
@@ -212,10 +215,15 @@ class TrainSliderProcess(BaseSDTrainProcess):
save_file(state_dict, self.slider_config.prompt_tensors)
prompt_pairs = []
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
for target in self.slider_config.targets:
erase_negative = len(target.positive.strip()) == 0
enhance_positive = len(target.negative.strip()) == 0
both = not erase_negative and not enhance_positive
if both or erase_negative:
print("Encoding erase negative")
prompt_pairs += [
# erase standard
EncodedPromptPair(
@@ -234,6 +242,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
),
]
if both or enhance_positive:
print("Encoding enhance positive")
prompt_pairs += [
# enhance standard, swap pos neg
EncodedPromptPair(
@@ -251,7 +260,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight
),
]
if both or enhance_positive:
# if both or enhance_positive:
if both:
print("Encoding erase positive (inverse)")
prompt_pairs += [
# erase inverted
EncodedPromptPair(
@@ -269,7 +280,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight
),
]
if both or erase_negative:
# if both or erase_negative:
if both:
print("Encoding enhance negative (inverse)")
prompt_pairs += [
# enhance inverted
EncodedPromptPair(
@@ -341,10 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
]
target_class = prompt_pair.target_class
neutral = prompt_pair.neutral
negative = prompt_pair.negative_target
positive = prompt_pair.positive_target
weight = prompt_pair.weight
multiplier = prompt_pair.multiplier