mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Various features and fixes. Too much brain fog to do a proper description
This commit is contained in:
@@ -160,6 +160,8 @@ class AdapterConfig:
|
||||
self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
|
||||
if self.train_only_image_encoder:
|
||||
self.train_image_encoder = True
|
||||
self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
|
||||
'train_only_image_encoder_positional_embedding', False)
|
||||
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
|
||||
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
|
||||
self.safe_channels: int = kwargs.get('safe_channels', 2048)
|
||||
@@ -260,6 +262,7 @@ class TrainConfig:
|
||||
# multiplier applied to loos on regularization images
|
||||
self.reg_weight = kwargs.get('reg_weight', 1.0)
|
||||
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
|
||||
# dropout that happens before encoding. It functions independently per text encoder
|
||||
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
|
||||
@@ -385,6 +388,11 @@ class ModelConfig:
|
||||
self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4
|
||||
self.unet_path = kwargs.get("unet_path", None)
|
||||
self.unet_sample_size = kwargs.get("unet_sample_size", None)
|
||||
self.vae_device = kwargs.get("vae_device", None)
|
||||
self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
|
||||
self.te_device = kwargs.get("te_device", None)
|
||||
self.te_dtype = kwargs.get("te_dtype", self.dtype)
|
||||
pass
|
||||
|
||||
|
||||
class EMAConfig:
|
||||
|
||||
Reference in New Issue
Block a user