Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.

This commit is contained in:
Jaret Burkett
2023-11-03 05:01:16 -06:00
parent ceaf1d9454
commit d35733ac06
8 changed files with 3569 additions and 75 deletions

View File

@@ -2,11 +2,13 @@ import copy
import glob
import inspect
import json
import shutil
from collections import OrderedDict
import os
from typing import Union, List
import numpy as np
import yaml
from diffusers import T2IAdapter
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
@@ -36,7 +38,8 @@ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
parse_metadata_from_safetensors
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
import gc
@@ -265,24 +268,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
def clean_up_saves(self):
# remove old saves
# get latest saved step
latest_item = None
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors but NOT {job_name}.safetensors
pattern = f"{self.job.name}_*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > self.save_config.max_step_saves_to_keep:
# remove all but the latest max_step_saves_to_keep
files.sort(key=os.path.getctime)
for file in files[:-self.save_config.max_step_saves_to_keep]:
self.print(f"Removing old save: {file}")
os.remove(file)
# see if a yaml file with same name exists
yaml_file = os.path.splitext(file)[0] + ".yaml"
if os.path.exists(yaml_file):
os.remove(yaml_file)
return latest_file
else:
return None
# pattern is {job_name}_{zero_filled_step} for both files and directories
pattern = f"{self.job.name}_*"
items = glob.glob(os.path.join(self.save_root, pattern))
# Separate files and directories
safetensors_files = [f for f in items if f.endswith('.safetensors')]
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
# Combine the list and sort by creation time
combined_items = safetensors_files + directories
combined_items.sort(key=os.path.getctime)
# remove all but the latest max_step_saves_to_keep
items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
for item in items_to_remove:
self.print(f"Removing old save: {item}")
if os.path.isdir(item):
shutil.rmtree(item)
else:
os.remove(item)
# see if a yaml file with same name exists
yaml_file = os.path.splitext(item)[0] + ".yaml"
if os.path.exists(yaml_file):
os.remove(yaml_file)
if combined_items:
latest_item = combined_items[-1]
return latest_item
def post_save_hook(self, save_path):
# override in subclass
@@ -366,6 +377,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
if self.save_config.save_format == "diffusers":
# saving as a folder path
file_path = file_path.replace('.safetensors', '')
# convert it back to normal object
save_meta = parse_metadata_from_safetensors(save_meta)
self.sd.save(
file_path,
save_meta,
@@ -415,24 +431,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
if name == None:
name = self.job.name
# get latest saved step
latest_path = None
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
pattern = f"{name}*{post}.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
return latest_file
else:
return None
# Define patterns for both files and directories
patterns = [
f"{name}*{post}.safetensors",
f"{name}*{post}.pt",
f"{name}*{post}"
]
# Search for both files and directories
paths = []
for pattern in patterns:
paths.extend(glob.glob(os.path.join(self.save_root, pattern)))
# Filter out non-existent paths and sort by creation time
if paths:
paths = [p for p in paths if os.path.exists(p)]
latest_path = max(paths, key=os.path.getctime)
return latest_path
def load_training_state_from_metadata(self, path):
meta = load_metadata_from_safetensors(path)
# if path is folder, then it is diffusers
if os.path.isdir(path):
meta_path = os.path.join(path, 'aitk_meta.yaml')
# load it
with open(meta_path, 'r') as f:
meta = yaml.load(f, Loader=yaml.FullLoader)
else:
meta = load_metadata_from_safetensors(path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
self.step_num = meta['training_info']['step']
@@ -731,14 +758,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)