Many bug fixes. Ip adapter bug fixes. Added noise to unconditional, it works better. added an ilora adapter for 1 shotting LoRAs

This commit is contained in:
Jaret Burkett
2024-01-28 08:20:03 -07:00
parent f17ad8d794
commit 92b9c71d44
10 changed files with 352 additions and 56 deletions

View File

@@ -113,6 +113,7 @@ class NetworkConfig:
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.dropout: Union[float, None] = kwargs.get('dropout', None)
self.network_kwargs: dict = kwargs.get('network_kwargs', {})
self.lorm_config: Union[LoRMConfig, None] = None
lorm = kwargs.get('lorm', None)
@@ -153,10 +154,14 @@ class AdapterConfig:
self.num_tokens: int = num_tokens
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
if self.train_only_image_encoder:
self.train_image_encoder = True
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
self.safe_channels: int = kwargs.get('safe_channels', 2048)
self.safe_tokens: int = kwargs.get('safe_tokens', 8)
self.quad_image: bool = kwargs.get('quad_image', False)
# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
@@ -211,6 +216,7 @@ class TrainConfig:
self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.force_first_sample = kwargs.get('force_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
@@ -275,7 +281,9 @@ class TrainConfig:
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
self.do_cfg = kwargs.get('do_cfg', False)
self.do_random_cfg = kwargs.get('do_random_cfg', False)
self.cfg_scale = kwargs.get('cfg_scale', 1.0)
self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
class ModelConfig: