mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added experimental wavelet loss
This commit is contained in:
@@ -405,7 +405,7 @@ class TrainConfig:
|
||||
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
|
||||
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
|
||||
|
||||
self.loss_type = kwargs.get('loss_type', 'mse')
|
||||
self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet
|
||||
|
||||
# scale the prediction by this. Increase for more detail, decrease for less
|
||||
self.pred_scaler = kwargs.get('pred_scaler', 1.0)
|
||||
|
||||
Reference in New Issue
Block a user