mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-01 11:29:46 +00:00
added dropout to LoRA networks
This commit is contained in:
@@ -541,6 +541,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
conv_alpha=self.network_config.conv_alpha,
|
||||
is_sdxl=self.model_config.is_xl,
|
||||
is_v2=self.model_config.is_v2,
|
||||
dropout=self.network_config.dropout
|
||||
)
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Literal
|
||||
from typing import List, Optional, Literal, Union
|
||||
import random
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
@@ -55,6 +55,7 @@ class NetworkConfig:
|
||||
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
|
||||
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
|
||||
self.normalize = kwargs.get('normalize', False)
|
||||
self.dropout: Union[float, None] = kwargs.get('dropout', None)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
|
||||
Reference in New Issue
Block a user