Various features and fixes. Too much brain fog to do a proper description

This commit is contained in:
Jaret Burkett
2024-07-18 07:34:14 -06:00
parent 58dffd43a8
commit 11e426fdf1
6 changed files with 119 additions and 25 deletions

View File

@@ -160,6 +160,8 @@ class AdapterConfig:
self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
if self.train_only_image_encoder:
self.train_image_encoder = True
self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
'train_only_image_encoder_positional_embedding', False)
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
self.safe_channels: int = kwargs.get('safe_channels', 2048)
@@ -260,6 +262,7 @@ class TrainConfig:
# multiplier applied to loos on regularization images
self.reg_weight = kwargs.get('reg_weight', 1.0)
self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
# dropout that happens before encoding. It functions independently per text encoder
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
@@ -385,6 +388,11 @@ class ModelConfig:
self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4
self.unet_path = kwargs.get("unet_path", None)
self.unet_sample_size = kwargs.get("unet_sample_size", None)
self.vae_device = kwargs.get("vae_device", None)
self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
self.te_device = kwargs.get("te_device", None)
self.te_dtype = kwargs.get("te_dtype", self.dtype)
pass
class EMAConfig: