mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-24 16:29:26 +00:00
Fixes for aug pipeline
This commit is contained in:
@@ -126,9 +126,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
return batch
|
||||
|
||||
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
|
||||
|
||||
self.timer.start('preprocess_batch')
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
@@ -127,7 +127,7 @@ class TrainConfig:
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
@@ -262,6 +262,7 @@ class DatasetConfig:
|
||||
# https://albumentations.ai/docs/api_reference/augmentations/transforms
|
||||
# augmentations are returned as a separate image and cannot currently be cached
|
||||
self.augmentations: List[dict] = kwargs.get('augmentations', None)
|
||||
self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False)
|
||||
|
||||
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
|
||||
|
||||
|
||||
@@ -496,10 +496,17 @@ class AugmentationFileItemDTOMixin:
|
||||
self.has_augmentations = False
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
# self.augmentations: Union[None, List[Augments]] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
if dataset_config.augmentations is not None and len(dataset_config.augmentations) > 0:
|
||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.build_augmentation_transform()
|
||||
|
||||
def build_augmentation_transform(self: 'FileItemDTO'):
|
||||
if self.dataset_config.augmentations is not None and len(self.dataset_config.augmentations) > 0:
|
||||
self.has_augmentations = True
|
||||
augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
|
||||
augmentations = [Augments(**aug) for aug in self.dataset_config.augmentations]
|
||||
|
||||
if self.dataset_config.shuffle_augmentations:
|
||||
random.shuffle(augmentations)
|
||||
|
||||
augmentation_list = []
|
||||
for aug in augmentations:
|
||||
# make sure method name is valid
|
||||
@@ -513,6 +520,10 @@ class AugmentationFileItemDTOMixin:
|
||||
|
||||
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
|
||||
|
||||
# rebuild each time if shuffle
|
||||
if self.dataset_config.shuffle_augmentations:
|
||||
self.build_augmentation_transform()
|
||||
|
||||
# save the original tensor
|
||||
self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user