Change img multiplier math

This commit is contained in:
Jaret Burkett
2024-07-30 11:33:41 -06:00
parent 443c996e7f
commit 47744373f2
4 changed files with 18 additions and 3 deletions

View File

@@ -457,7 +457,7 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.target_norm_std:
# seperate out the batch and channels
pred_std = noise_pred.std([2, 3], keepdim=True)
norm_std_loss = torch.abs(1.0 - pred_std).mean()
norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
loss = loss + norm_std_loss

View File

@@ -24,7 +24,7 @@ from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.ema import ExponentialMovingAverage
from toolkit.embedding import Embedding
from toolkit.image_utils import show_tensors, show_latents
from toolkit.image_utils import show_tensors, show_latents, reduce_contrast
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
@@ -811,7 +811,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if self.train_config.img_multiplier is not None:
imgs = imgs * self.train_config.img_multiplier
# do it ad contrast
imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents

View File

@@ -354,6 +354,7 @@ class TrainConfig:
# adds an additional loss to the network to encourage it output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
class ModelConfig:

View File

@@ -495,6 +495,19 @@ def on_exit():
cv2.destroyAllWindows()
def reduce_contrast(tensor, factor):
# Ensure factor is between 0 and 1
factor = max(0, min(factor, 1))
# Calculate the mean of the tensor
mean = torch.mean(tensor)
# Reduce contrast
adjusted_tensor = (tensor - mean) * factor + mean
# Clip values to ensure they stay within -1 to 1 range
return torch.clamp(adjusted_tensor, -1.0, 1.0)
atexit.register(on_exit)
if __name__ == "__main__":