mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-23 15:59:32 +00:00
Bug fixes
This commit is contained in:
@@ -282,12 +282,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# prepare meta
|
# prepare meta
|
||||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
|
prev_multiplier = self.network.multiplier
|
||||||
|
self.network.multiplier = 1.0
|
||||||
# TODO handle dreambooth, fine tuning, etc
|
# TODO handle dreambooth, fine tuning, etc
|
||||||
self.network.save_weights(
|
self.network.save_weights(
|
||||||
file_path,
|
file_path,
|
||||||
dtype=get_torch_dtype(self.save_config.dtype),
|
dtype=get_torch_dtype(self.save_config.dtype),
|
||||||
metadata=save_meta
|
metadata=save_meta
|
||||||
)
|
)
|
||||||
|
self.network.multiplier = prev_multiplier
|
||||||
else:
|
else:
|
||||||
self.sd.save(
|
self.sd.save(
|
||||||
file_path,
|
file_path,
|
||||||
@@ -639,7 +642,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
### HOOK ###
|
### HOOK ###
|
||||||
loss_dict = self.hook_train_loop()
|
loss_dict = self.hook_train_loop()
|
||||||
|
|
||||||
if self.train_config.optimizer.startswith('dadaptation'):
|
if self.train_config.optimizer.lower().startswith('dadaptation'):
|
||||||
learning_rate = (
|
learning_rate = (
|
||||||
optimizer.param_groups[0]["d"] *
|
optimizer.param_groups[0]["d"] *
|
||||||
optimizer.param_groups[0]["lr"]
|
optimizer.param_groups[0]["lr"]
|
||||||
|
|||||||
@@ -115,12 +115,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
cache[prompt] = self.sd.encode_prompt(prompt)
|
cache[prompt] = self.sd.encode_prompt(prompt)
|
||||||
for resolution in self.slider_config.resolutions:
|
for resolution in self.slider_config.resolutions:
|
||||||
width, height = resolution
|
width, height = resolution
|
||||||
only_erase = len(target.positive.strip()) == 0
|
erase_negative = len(target.positive.strip()) == 0
|
||||||
only_enhance = len(target.negative.strip()) == 0
|
enhance_positive = len(target.negative.strip()) == 0
|
||||||
|
|
||||||
both = not only_erase and not only_enhance
|
both = not erase_negative and not enhance_positive
|
||||||
|
|
||||||
if only_erase and only_enhance:
|
if erase_negative and enhance_positive:
|
||||||
raise ValueError("target must have at least one of positive or negative or both")
|
raise ValueError("target must have at least one of positive or negative or both")
|
||||||
# for slider we need to have an enhancer, an eraser, and then
|
# for slider we need to have an enhancer, an eraser, and then
|
||||||
# an inverse with negative weights to balance the network
|
# an inverse with negative weights to balance the network
|
||||||
@@ -128,7 +128,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
# we only perform actions of enhancing and erasing on the negative
|
# we only perform actions of enhancing and erasing on the negative
|
||||||
# todo work on way to do all of this in one shot
|
# todo work on way to do all of this in one shot
|
||||||
|
|
||||||
if both or only_erase:
|
if both or erase_negative:
|
||||||
prompt_pairs += [
|
prompt_pairs += [
|
||||||
# erase standard
|
# erase standard
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -143,7 +143,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
weight=target.weight
|
weight=target.weight
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
if both or only_enhance:
|
if both or enhance_positive:
|
||||||
prompt_pairs += [
|
prompt_pairs += [
|
||||||
# enhance standard, swap pos neg
|
# enhance standard, swap pos neg
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -158,7 +158,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
weight=target.weight
|
weight=target.weight
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
if both:
|
if both or enhance_positive:
|
||||||
prompt_pairs += [
|
prompt_pairs += [
|
||||||
# erase inverted
|
# erase inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -173,6 +173,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
weight=target.weight
|
weight=target.weight
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
if both or erase_negative:
|
||||||
prompt_pairs += [
|
prompt_pairs += [
|
||||||
# enhance inverted
|
# enhance inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
|
|||||||
Reference in New Issue
Block a user