add dataset-level distillation-style regularization

This commit is contained in:
max
2025-09-18 01:11:19 +03:00
parent b95c17dc17
commit e4ae97e790
2 changed files with 6 additions and 5 deletions

View File

@@ -806,6 +806,7 @@ class DatasetConfig:
self.buckets: bool = kwargs.get('buckets', True)
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
self.is_reg: bool = kwargs.get('is_reg', False)
self.prior_reg: bool = kwargs.get('prior_reg', False)
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)