Regularize the slider targets.

This commit is contained in:
Jaret Burkett
2025-09-17 09:36:33 -06:00
parent 218f673e3d
commit 24a576ad07
2 changed files with 29 additions and 4 deletions

View File

@@ -19,6 +19,18 @@ class ConceptSliderTrainerConfig:
self.anchor_class: Optional[str] = kwargs.get("anchor_class", None)
def norm_like_tensor(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Normalize the tensor to have the same mean and std as the target tensor."""
tensor_mean = tensor.mean()
tensor_std = tensor.std()
target_mean = target.mean()
target_std = target.std()
normalized_tensor = (tensor - tensor_mean) / (
tensor_std + 1e-8
) * target_std + target_mean
return normalized_tensor
class ConceptSliderTrainer(DiffusionTrainer):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
@@ -100,7 +112,6 @@ class ConceptSliderTrainer(DiffusionTrainer):
if self.network is not None:
was_network_active = self.network.is_active
self.network.is_active = False
# do out prior preds first
with torch.no_grad():
@@ -175,7 +186,7 @@ class ConceptSliderTrainer(DiffusionTrainer):
# erase_positive_target = neutral_pred - guidance_scale * (
# positive_pred - negative_pred
# )
positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred)
negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred)
@@ -183,6 +194,20 @@ class ConceptSliderTrainer(DiffusionTrainer):
enhance_negative_target = neutral_pred + guidance_scale * negative
erase_negative_target = neutral_pred - guidance_scale * negative
erase_positive_target = neutral_pred - guidance_scale * positive
# normalize to neutral std/mean
enhance_positive_target = norm_like_tensor(
enhance_positive_target, neutral_pred
)
enhance_negative_target = norm_like_tensor(
enhance_negative_target, neutral_pred
)
erase_negative_target = norm_like_tensor(
erase_negative_target, neutral_pred
)
erase_positive_target = norm_like_tensor(
erase_positive_target, neutral_pred
)
if was_unet_training:
self.sd.unet.train()
@@ -230,7 +255,7 @@ class ConceptSliderTrainer(DiffusionTrainer):
anchor_loss = torch.zeros_like(erase_loss)
else:
anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target)
anchor_loss = anchor_loss * self.slider.anchor_strength
# send backward now because gradient checkpointing needs network polarity intact

View File

@@ -1 +1 @@
VERSION = "0.5.9"
VERSION = "0.5.10"