diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index f9b810fe..c5e6d71e 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -541,6 +541,7 @@ class BaseSDTrainProcess(BaseTrainProcess): conv_alpha=self.network_config.conv_alpha, is_sdxl=self.model_config.is_xl, is_v2=self.model_config.is_v2, + dropout=self.network_config.dropout ) self.network.force_to(self.device_torch, dtype=dtype) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 6c7e7895..466b2b97 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional, Literal +from typing import List, Optional, Literal, Union import random ImgExt = Literal['jpg', 'png', 'webp'] @@ -55,6 +55,7 @@ class NetworkConfig: self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) self.normalize = kwargs.get('normalize', False) + self.dropout: Union[float, None] = kwargs.get('dropout', None) class EmbeddingConfig: