mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Allow augmentations and targeting different loss types fron the config file
This commit is contained in:
@@ -88,7 +88,7 @@ class EmbeddingConfig:
|
||||
|
||||
|
||||
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
|
||||
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -127,6 +127,7 @@ class TrainConfig:
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
@@ -258,7 +259,13 @@ class DatasetConfig:
|
||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
||||
|
||||
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
|
||||
# https://albumentations.ai/docs/api_reference/augmentations/transforms
|
||||
# augmentations are returned as a separate image and cannot currently be cached
|
||||
self.augmentations: List[dict] = kwargs.get('augmentations', None)
|
||||
|
||||
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
|
||||
|
||||
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
|
||||
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
|
||||
self.cache_latents = False
|
||||
self.cache_latents_to_disk = False
|
||||
|
||||
Reference in New Issue
Block a user