mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 08:13:58 +00:00
Change img multiplier math
This commit is contained in:
@@ -457,7 +457,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.train_config.target_norm_std:
|
||||
# separate out the batch and channels
|
||||
pred_std = noise_pred.std([2, 3], keepdim=True)
|
||||
norm_std_loss = torch.abs(1.0 - pred_std).mean()
|
||||
norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
|
||||
loss = loss + norm_std_loss
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.ema import ExponentialMovingAverage
|
||||
from toolkit.embedding import Embedding
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.image_utils import show_tensors, show_latents, reduce_contrast
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
|
||||
@@ -811,7 +811,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if self.train_config.img_multiplier is not None:
|
||||
imgs = imgs * self.train_config.img_multiplier
|
||||
# do it as contrast instead
|
||||
imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
|
||||
@@ -354,6 +354,7 @@ class TrainConfig:
|
||||
|
||||
# adds an additional loss to the network to encourage it to output a normalized standard deviation
|
||||
self.target_norm_std = kwargs.get('target_norm_std', None)
|
||||
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
@@ -495,6 +495,19 @@ def on_exit():
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def reduce_contrast(tensor, factor):
    """Blend *tensor* toward its global mean to lower contrast.

    ``factor`` is clamped to [0, 1]: 1 leaves values unchanged, 0 collapses
    everything to the mean. The result is clamped to [-1, 1]
    (assumes the tensor is already normalized to that range — typical
    for image tensors in this codebase).
    """
    # Clamp the blend factor into the valid [0, 1] range.
    factor = min(max(factor, 0), 1)

    # Global mean acts as the contrast pivot.
    center = torch.mean(tensor)

    # Linear interpolation between the mean and the original values,
    # then keep the output inside the normalized image range.
    blended = center + (tensor - center) * factor
    return torch.clamp(blended, -1.0, 1.0)
|
||||
|
||||
atexit.register(on_exit)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user