diff --git a/extensions_built_in/concept_slider/ConceptSliderTrainer.py b/extensions_built_in/concept_slider/ConceptSliderTrainer.py index eba1d74e..6c9fce74 100644 --- a/extensions_built_in/concept_slider/ConceptSliderTrainer.py +++ b/extensions_built_in/concept_slider/ConceptSliderTrainer.py @@ -19,6 +19,18 @@ class ConceptSliderTrainerConfig: self.anchor_class: Optional[str] = kwargs.get("anchor_class", None) +def norm_like_tensor(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Normalize the tensor to have the same mean and std as the target tensor.""" + tensor_mean = tensor.mean() + tensor_std = tensor.std() + target_mean = target.mean() + target_std = target.std() + normalized_tensor = (tensor - tensor_mean) / ( + tensor_std + 1e-8 + ) * target_std + target_mean + return normalized_tensor + + class ConceptSliderTrainer(DiffusionTrainer): def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): super().__init__(process_id, job, config, **kwargs) @@ -100,7 +112,6 @@ class ConceptSliderTrainer(DiffusionTrainer): if self.network is not None: was_network_active = self.network.is_active self.network.is_active = False - # do out prior preds first with torch.no_grad(): @@ -175,7 +186,7 @@ class ConceptSliderTrainer(DiffusionTrainer): # erase_positive_target = neutral_pred - guidance_scale * ( # positive_pred - negative_pred # ) - + positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred) negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred) @@ -183,6 +194,20 @@ class ConceptSliderTrainer(DiffusionTrainer): enhance_negative_target = neutral_pred + guidance_scale * negative erase_negative_target = neutral_pred - guidance_scale * negative erase_positive_target = neutral_pred - guidance_scale * positive + + # normalize to neutral std/mean + enhance_positive_target = norm_like_tensor( + enhance_positive_target, neutral_pred + ) + enhance_negative_target = norm_like_tensor( + enhance_negative_target, neutral_pred + ) + erase_negative_target = norm_like_tensor( + erase_negative_target, neutral_pred + ) + erase_positive_target = norm_like_tensor( + erase_positive_target, neutral_pred + ) if was_unet_training: self.sd.unet.train() @@ -230,7 +255,7 @@ class ConceptSliderTrainer(DiffusionTrainer): anchor_loss = torch.zeros_like(erase_loss) else: anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) - + anchor_loss = anchor_loss * self.slider.anchor_strength # send backward now because gradient checkpointing needs network polarity intact diff --git a/version.py b/version.py index 63c762ca..dd1e6dbf 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.5.9" \ No newline at end of file +VERSION = "0.5.10" \ No newline at end of file