mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Regularize the slider targets.
This commit is contained in:
@@ -19,6 +19,18 @@ class ConceptSliderTrainerConfig:
|
||||
self.anchor_class: Optional[str] = kwargs.get("anchor_class", None)
|
||||
|
||||
|
||||
def norm_like_tensor(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize the tensor to have the same mean and std as the target tensor."""
|
||||
tensor_mean = tensor.mean()
|
||||
tensor_std = tensor.std()
|
||||
target_mean = target.mean()
|
||||
target_std = target.std()
|
||||
normalized_tensor = (tensor - tensor_mean) / (
|
||||
tensor_std + 1e-8
|
||||
) * target_std + target_mean
|
||||
return normalized_tensor
|
||||
|
||||
|
||||
class ConceptSliderTrainer(DiffusionTrainer):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super().__init__(process_id, job, config, **kwargs)
|
||||
@@ -101,7 +113,6 @@ class ConceptSliderTrainer(DiffusionTrainer):
|
||||
was_network_active = self.network.is_active
|
||||
self.network.is_active = False
|
||||
|
||||
|
||||
# do out prior preds first
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -184,6 +195,20 @@ class ConceptSliderTrainer(DiffusionTrainer):
|
||||
erase_negative_target = neutral_pred - guidance_scale * negative
|
||||
erase_positive_target = neutral_pred - guidance_scale * positive
|
||||
|
||||
# normalize to neutral std/mean
|
||||
enhance_positive_target = norm_like_tensor(
|
||||
enhance_positive_target, neutral_pred
|
||||
)
|
||||
enhance_negative_target = norm_like_tensor(
|
||||
enhance_negative_target, neutral_pred
|
||||
)
|
||||
erase_negative_target = norm_like_tensor(
|
||||
erase_negative_target, neutral_pred
|
||||
)
|
||||
erase_positive_target = norm_like_tensor(
|
||||
erase_positive_target, neutral_pred
|
||||
)
|
||||
|
||||
if was_unet_training:
|
||||
self.sd.unet.train()
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.5.9"
|
||||
VERSION = "0.5.10"
|
||||
Reference in New Issue
Block a user