Actually use the save dtype from the config file.

This commit is contained in:
Jaret Burkett
2024-08-13 17:08:27 -06:00
parent f7cf2f866f
commit 00bd3d54a3

View File

@@ -18,7 +18,7 @@ if TYPE_CHECKING:
class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
self.dtype: str = kwargs.get('save_dtype', 'float16')
self.dtype: str = kwargs.get('dtype', 'float16')
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
if self.save_format not in ['safetensors', 'diffusers']: