Added support for training SSD-1B. Added support for saving models in diffusers format. SSD-1B models can currently be saved in safetensors format, but diffusers cannot load them yet.

This commit is contained in:
Jaret Burkett
2023-11-03 05:01:16 -06:00
parent ceaf1d9454
commit d35733ac06
8 changed files with 3569 additions and 75 deletions

View File

@@ -9,12 +9,17 @@ from toolkit.prompt_utils import PromptEmbeds
ImgExt = Literal['jpg', 'png', 'webp']
SaveFormat = Literal['safetensors', 'diffusers']
class SaveConfig:
    """Checkpoint-saving options, populated from keyword arguments.

    Recognised keys and their defaults:
        save_every: 1000
        save_dtype: 'float16' (stored on the instance as ``dtype``)
        max_step_saves_to_keep: 5
        save_format: 'safetensors'

    Raises:
        ValueError: if ``save_format`` is neither 'safetensors' nor
            'diffusers'.
    """

    def __init__(self, **kwargs):
        read = kwargs.get

        self.save_every: int = read('save_every', 1000)
        # note the key is 'save_dtype' while the attribute is plain 'dtype'
        self.dtype: str = read('save_dtype', 'float16')
        self.max_step_saves_to_keep: int = read('max_step_saves_to_keep', 5)
        self.save_format: SaveFormat = read('save_format', 'safetensors')

        # reject anything outside the two supported output formats
        format_is_known = self.save_format in ('safetensors', 'diffusers')
        if format_is_known:
            return
        raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
class LogingConfig:
@@ -187,7 +192,7 @@ class TrainConfig:
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
# match the norm of the noise before computing loss. This will help the model maintain its
#current understandin of the brightness of images.
# current understanding of the brightness of images.
self.match_noise_norm = kwargs.get('match_noise_norm', False)
@@ -229,6 +234,7 @@ class ModelConfig:
self.name_or_path: str = kwargs.get('name_or_path', None)
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_ssd: bool = kwargs.get('is_ssd', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
self.vae_path = kwargs.get('vae_path', None)
@@ -242,6 +248,10 @@ class ModelConfig:
if self.name_or_path is None:
raise ValueError('name_or_path must be specified')
if self.is_ssd:
# set is_xl to True since SSD is mostly the same architecture as SDXL
self.is_xl = True
class ReferenceDatasetConfig:
def __init__(self, **kwargs):