Bug fixes

This commit is contained in:
Jaret Burkett
2023-07-29 13:39:57 -06:00
parent 2305e55c82
commit 9cdf2dd6e4
2 changed files with 12 additions and 8 deletions

View File

@@ -282,12 +282,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
# prepare meta # prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name) save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None: if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
# TODO handle dreambooth, fine tuning, etc # TODO handle dreambooth, fine tuning, etc
self.network.save_weights( self.network.save_weights(
file_path, file_path,
dtype=get_torch_dtype(self.save_config.dtype), dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta metadata=save_meta
) )
self.network.multiplier = prev_multiplier
else: else:
self.sd.save( self.sd.save(
file_path, file_path,
@@ -639,7 +642,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ### ### HOOK ###
loss_dict = self.hook_train_loop() loss_dict = self.hook_train_loop()
if self.train_config.optimizer.startswith('dadaptation'): if self.train_config.optimizer.lower().startswith('dadaptation'):
learning_rate = ( learning_rate = (
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"] optimizer.param_groups[0]["lr"]

View File

@@ -115,12 +115,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
cache[prompt] = self.sd.encode_prompt(prompt) cache[prompt] = self.sd.encode_prompt(prompt)
for resolution in self.slider_config.resolutions: for resolution in self.slider_config.resolutions:
width, height = resolution width, height = resolution
only_erase = len(target.positive.strip()) == 0 erase_negative = len(target.positive.strip()) == 0
only_enhance = len(target.negative.strip()) == 0 enhance_positive = len(target.negative.strip()) == 0
both = not only_erase and not only_enhance both = not erase_negative and not enhance_positive
if only_erase and only_enhance: if erase_negative and enhance_positive:
raise ValueError("target must have at least one of positive or negative or both") raise ValueError("target must have at least one of positive or negative or both")
# for slider we need to have an enhancer, an eraser, and then # for slider we need to have an enhancer, an eraser, and then
# an inverse with negative weights to balance the network # an inverse with negative weights to balance the network
@@ -128,7 +128,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# we only perform actions of enhancing and erasing on the negative # we only perform actions of enhancing and erasing on the negative
# todo work on way to do all of this in one shot # todo work on way to do all of this in one shot
if both or only_erase: if both or erase_negative:
prompt_pairs += [ prompt_pairs += [
# erase standard # erase standard
EncodedPromptPair( EncodedPromptPair(
@@ -143,7 +143,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight weight=target.weight
), ),
] ]
if both or only_enhance: if both or enhance_positive:
prompt_pairs += [ prompt_pairs += [
# enhance standard, swap pos neg # enhance standard, swap pos neg
EncodedPromptPair( EncodedPromptPair(
@@ -158,7 +158,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight weight=target.weight
), ),
] ]
if both: if both or enhance_positive:
prompt_pairs += [ prompt_pairs += [
# erase inverted # erase inverted
EncodedPromptPair( EncodedPromptPair(
@@ -173,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
weight=target.weight weight=target.weight
), ),
] ]
if both or erase_negative:
prompt_pairs += [ prompt_pairs += [
# enhance inverted # enhance inverted
EncodedPromptPair( EncodedPromptPair(