mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Merge pull request #426 from squewel/prior_reg
Dataset-level prior regularization
This commit is contained in:
@@ -813,6 +813,7 @@ class DatasetConfig:
|
||||
self.buckets: bool = kwargs.get('buckets', True)
|
||||
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||
self.prior_reg: bool = kwargs.get('prior_reg', False)
|
||||
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||
|
||||
@@ -121,6 +121,7 @@ class FileItemDTO(
|
||||
|
||||
self.network_weight: float = self.dataset_config.network_weight
|
||||
self.is_reg = self.dataset_config.is_reg
|
||||
self.prior_reg = self.dataset_config.prior_reg
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
|
||||
def cleanup(self):
|
||||
|
||||
Reference in New Issue
Block a user