mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added Model rescale and prepared a release upgrade
This commit is contained in:
@@ -46,8 +46,8 @@ class EncodedPromptPair:
|
||||
negative_target,
|
||||
negative_target_with_neutral,
|
||||
neutral,
|
||||
both_targets,
|
||||
empty_prompt,
|
||||
both_targets,
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=1.0,
|
||||
weight=1.0
|
||||
@@ -123,23 +123,24 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
|
||||
|
||||
# read line by line from file
|
||||
with open(self.slider_config.prompt_file, 'r') as f:
|
||||
self.prompt_txt_list = f.readlines()
|
||||
# clean empty lines
|
||||
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||
if self.slider_config.prompt_file:
|
||||
with open(self.slider_config.prompt_file, 'r') as f:
|
||||
self.prompt_txt_list = f.readlines()
|
||||
# clean empty lines
|
||||
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||
|
||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||
|
||||
|
||||
if not self.slider_config.prompt_tensors:
|
||||
# shuffle
|
||||
random.shuffle(self.prompt_txt_list)
|
||||
# trim to max steps
|
||||
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
|
||||
# trim list to our max steps
|
||||
|
||||
cache = PromptEmbedsCache()
|
||||
|
||||
if not self.slider_config.prompt_tensors:
|
||||
# shuffle
|
||||
random.shuffle(self.prompt_txt_list)
|
||||
# trim to max steps
|
||||
self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
|
||||
# trim list to our max steps
|
||||
|
||||
|
||||
# get encoded latents for our prompts
|
||||
with torch.no_grad():
|
||||
if self.slider_config.prompt_tensors is not None:
|
||||
@@ -169,7 +170,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# encode empty_prompt
|
||||
cache[empty_prompt] = self.sd.encode_prompt(empty_prompt)
|
||||
|
||||
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
|
||||
neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
|
||||
|
||||
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
|
||||
for target in self.slider_config.targets:
|
||||
prompt_list = [
|
||||
f"{target.target_class}", # target_class
|
||||
@@ -212,10 +215,15 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
save_file(state_dict, self.slider_config.prompt_tensors)
|
||||
|
||||
prompt_pairs = []
|
||||
for neutral in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
|
||||
for neutral in tqdm(neutral_list, desc="Encoding prompts", leave=False):
|
||||
for target in self.slider_config.targets:
|
||||
erase_negative = len(target.positive.strip()) == 0
|
||||
enhance_positive = len(target.negative.strip()) == 0
|
||||
|
||||
both = not erase_negative and not enhance_positive
|
||||
|
||||
if both or erase_negative:
|
||||
print("Encoding erase negative")
|
||||
prompt_pairs += [
|
||||
# erase standard
|
||||
EncodedPromptPair(
|
||||
@@ -234,6 +242,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
print("Encoding enhance positive")
|
||||
prompt_pairs += [
|
||||
# enhance standard, swap pos neg
|
||||
EncodedPromptPair(
|
||||
@@ -251,7 +260,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
if both or enhance_positive:
|
||||
# if both or enhance_positive:
|
||||
if both:
|
||||
print("Encoding erase positive (inverse)")
|
||||
prompt_pairs += [
|
||||
# erase inverted
|
||||
EncodedPromptPair(
|
||||
@@ -269,7 +280,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
weight=target.weight
|
||||
),
|
||||
]
|
||||
if both or erase_negative:
|
||||
# if both or erase_negative:
|
||||
if both:
|
||||
print("Encoding enhance negative (inverse)")
|
||||
prompt_pairs += [
|
||||
# enhance inverted
|
||||
EncodedPromptPair(
|
||||
@@ -341,10 +354,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
|
||||
]
|
||||
|
||||
target_class = prompt_pair.target_class
|
||||
neutral = prompt_pair.neutral
|
||||
negative = prompt_pair.negative_target
|
||||
positive = prompt_pair.positive_target
|
||||
weight = prompt_pair.weight
|
||||
multiplier = prompt_pair.multiplier
|
||||
|
||||
|
||||
Reference in New Issue
Block a user