Fixes for aug pipeline

This commit is contained in:
Jaret Burkett
2023-10-20 13:24:07 -06:00
parent 07bf7bd7de
commit d46112a354
3 changed files with 21 additions and 5 deletions

View File

@@ -126,9 +126,13 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss.mean()
return loss
def hook_train_loop(self, batch):
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()

View File

@@ -127,7 +127,7 @@ class TrainConfig:
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
@@ -262,6 +262,7 @@ class DatasetConfig:
# https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)
self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False)
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0

View File

@@ -496,10 +496,17 @@ class AugmentationFileItemDTOMixin:
self.has_augmentations = False
self.unaugmented_tensor: Union[torch.Tensor, None] = None
# self.augmentations: Union[None, List[Augments]] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.augmentations is not None and len(dataset_config.augmentations) > 0:
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.build_augmentation_transform()
def build_augmentation_transform(self: 'FileItemDTO'):
if self.dataset_config.augmentations is not None and len(self.dataset_config.augmentations) > 0:
self.has_augmentations = True
augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
augmentations = [Augments(**aug) for aug in self.dataset_config.augmentations]
if self.dataset_config.shuffle_augmentations:
random.shuffle(augmentations)
augmentation_list = []
for aug in augmentations:
# make sure method name is valid
@@ -513,6 +520,10 @@ class AugmentationFileItemDTOMixin:
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
# rebuild each time if shuffle
if self.dataset_config.shuffle_augmentations:
self.build_augmentation_transform()
# save the original tensor
self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)