mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Regularize the slider targets.
This commit is contained in:
@@ -19,6 +19,18 @@ class ConceptSliderTrainerConfig:
|
||||
self.anchor_class: Optional[str] = kwargs.get("anchor_class", None)
|
||||
|
||||
|
||||
def norm_like_tensor(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Rescale ``tensor`` so its global mean and std match those of ``target``.

    Statistics are taken over every element of each tensor (no per-sample or
    per-channel reduction), using torch's default ``std``.

    Half-precision inputs (float16 / bfloat16) are upcast to float32 for the
    computation: the ``1e-8`` guard is smaller than the smallest float16
    subnormal, so in float16 it would round away and a constant input would
    divide by zero and yield NaN. The result is cast back to ``tensor``'s
    dtype, so callers see the same dtype they passed in.

    Args:
        tensor: Values to re-standardize.
        target: Tensor whose mean and std are imposed on ``tensor``.

    Returns:
        A tensor shaped like ``tensor`` with the dtype of ``tensor``.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    src = tensor.float() if tensor.dtype in half_dtypes else tensor
    ref = target.float() if target.dtype in half_dtypes else target

    # Guards the division when ``tensor`` has (near-)zero spread.
    eps = 1e-8

    tensor_mean = src.mean()
    tensor_std = src.std()
    target_mean = ref.mean()
    target_std = ref.std()

    # Standardize to zero-mean / unit-std, then map onto the target statistics.
    normalized_tensor = (src - tensor_mean) / (tensor_std + eps) * target_std + target_mean

    # No-op for inputs that were not upcast.
    return normalized_tensor.to(tensor.dtype)
|
||||
|
||||
|
||||
class ConceptSliderTrainer(DiffusionTrainer):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
@@ -100,7 +112,6 @@ class ConceptSliderTrainer(DiffusionTrainer):
|
||||
if self.network is not None:
|
||||
was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
|
||||
|
||||
# do our prior preds first
|
||||
with torch.no_grad():
|
||||
@@ -175,7 +186,7 @@ class ConceptSliderTrainer(DiffusionTrainer):
|
||||
# erase_positive_target = neutral_pred - guidance_scale * (
|
||||
# positive_pred - negative_pred
|
||||
# )
|
||||
|
||||
|
||||
positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred)
|
||||
negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred)
|
||||
|
||||
@@ -183,6 +194,20 @@ class ConceptSliderTrainer(DiffusionTrainer):
|
||||
enhance_negative_target = neutral_pred + guidance_scale * negative
|
||||
erase_negative_target = neutral_pred - guidance_scale * negative
|
||||
erase_positive_target = neutral_pred - guidance_scale * positive
|
||||
|
||||
# normalize to neutral std/mean
|
||||
enhance_positive_target = norm_like_tensor(
|
||||
enhance_positive_target, neutral_pred
|
||||
)
|
||||
enhance_negative_target = norm_like_tensor(
|
||||
enhance_negative_target, neutral_pred
|
||||
)
|
||||
erase_negative_target = norm_like_tensor(
|
||||
erase_negative_target, neutral_pred
|
||||
)
|
||||
erase_positive_target = norm_like_tensor(
|
||||
erase_positive_target, neutral_pred
|
||||
)
|
||||
|
||||
if was_unet_training:
|
||||
self.sd.unet.train()
|
||||
@@ -230,7 +255,7 @@ class ConceptSliderTrainer(DiffusionTrainer):
|
||||
anchor_loss = torch.zeros_like(erase_loss)
|
||||
else:
|
||||
anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target)
|
||||
|
||||
|
||||
anchor_loss = anchor_loss * self.slider.anchor_strength
|
||||
|
||||
# send backward now because gradient checkpointing needs network polarity intact
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.5.9"
|
||||
VERSION = "0.5.10"
|
||||
Reference in New Issue
Block a user