mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.
This commit is contained in:
@@ -9,12 +9,17 @@ from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
|
||||
SaveFormat = Literal['safetensors', 'diffusers']
|
||||
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.save_every: int = kwargs.get('save_every', 1000)
|
||||
self.dtype: str = kwargs.get('save_dtype', 'float16')
|
||||
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
|
||||
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
|
||||
if self.save_format not in ['safetensors', 'diffusers']:
|
||||
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
|
||||
|
||||
|
||||
class LogingConfig:
|
||||
@@ -187,7 +192,7 @@ class TrainConfig:
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
|
||||
# match the norm of the noise before computing loss. This will help the model maintain its
|
||||
#current understandin of the brightness of images.
|
||||
# current understandin of the brightness of images.
|
||||
|
||||
self.match_noise_norm = kwargs.get('match_noise_norm', False)
|
||||
|
||||
@@ -229,6 +234,7 @@ class ModelConfig:
|
||||
self.name_or_path: str = kwargs.get('name_or_path', None)
|
||||
self.is_v2: bool = kwargs.get('is_v2', False)
|
||||
self.is_xl: bool = kwargs.get('is_xl', False)
|
||||
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||
self.vae_path = kwargs.get('vae_path', None)
|
||||
@@ -242,6 +248,10 @@ class ModelConfig:
|
||||
if self.name_or_path is None:
|
||||
raise ValueError('name_or_path must be specified')
|
||||
|
||||
if self.is_ssd:
|
||||
# sed sdxl as true since it is mostly the same architecture
|
||||
self.is_xl = True
|
||||
|
||||
|
||||
class ReferenceDatasetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user