mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.
This commit is contained in:
@@ -2,11 +2,13 @@ import copy
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from diffusers import T2IAdapter
|
||||
from safetensors.torch import save_file, load_file
|
||||
# from lycoris.config import PRESET
|
||||
@@ -36,7 +38,8 @@ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
|
||||
parse_metadata_from_safetensors
|
||||
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
|
||||
import gc
|
||||
|
||||
@@ -265,24 +268,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def clean_up_saves(self):
|
||||
# remove old saves
|
||||
# get latest saved step
|
||||
latest_item = None
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors but NOT {job_name}.safetensors
|
||||
pattern = f"{self.job.name}_*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > self.save_config.max_step_saves_to_keep:
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
files.sort(key=os.path.getctime)
|
||||
for file in files[:-self.save_config.max_step_saves_to_keep]:
|
||||
self.print(f"Removing old save: {file}")
|
||||
os.remove(file)
|
||||
# see if a yaml file with same name exists
|
||||
yaml_file = os.path.splitext(file)[0] + ".yaml"
|
||||
if os.path.exists(yaml_file):
|
||||
os.remove(yaml_file)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
# pattern is {job_name}_{zero_filled_step} for both files and directories
|
||||
pattern = f"{self.job.name}_*"
|
||||
items = glob.glob(os.path.join(self.save_root, pattern))
|
||||
# Separate files and directories
|
||||
safetensors_files = [f for f in items if f.endswith('.safetensors')]
|
||||
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
|
||||
# Combine the list and sort by creation time
|
||||
combined_items = safetensors_files + directories
|
||||
combined_items.sort(key=os.path.getctime)
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
for item in items_to_remove:
|
||||
self.print(f"Removing old save: {item}")
|
||||
if os.path.isdir(item):
|
||||
shutil.rmtree(item)
|
||||
else:
|
||||
os.remove(item)
|
||||
# see if a yaml file with same name exists
|
||||
yaml_file = os.path.splitext(item)[0] + ".yaml"
|
||||
if os.path.exists(yaml_file):
|
||||
os.remove(yaml_file)
|
||||
if combined_items:
|
||||
latest_item = combined_items[-1]
|
||||
return latest_item
|
||||
|
||||
def post_save_hook(self, save_path):
|
||||
# override in subclass
|
||||
@@ -366,6 +377,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
else:
|
||||
if self.save_config.save_format == "diffusers":
|
||||
# saving as a folder path
|
||||
file_path = file_path.replace('.safetensors', '')
|
||||
# convert it back to normal object
|
||||
save_meta = parse_metadata_from_safetensors(save_meta)
|
||||
self.sd.save(
|
||||
file_path,
|
||||
save_meta,
|
||||
@@ -415,24 +431,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if name == None:
|
||||
name = self.job.name
|
||||
# get latest saved step
|
||||
latest_path = None
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
|
||||
pattern = f"{name}*{post}.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
# try pt
|
||||
pattern = f"{name}*.pt"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
# Define patterns for both files and directories
|
||||
patterns = [
|
||||
f"{name}*{post}.safetensors",
|
||||
f"{name}*{post}.pt",
|
||||
f"{name}*{post}"
|
||||
]
|
||||
# Search for both files and directories
|
||||
paths = []
|
||||
for pattern in patterns:
|
||||
paths.extend(glob.glob(os.path.join(self.save_root, pattern)))
|
||||
|
||||
# Filter out non-existent paths and sort by creation time
|
||||
if paths:
|
||||
paths = [p for p in paths if os.path.exists(p)]
|
||||
latest_path = max(paths, key=os.path.getctime)
|
||||
|
||||
return latest_path
|
||||
|
||||
def load_training_state_from_metadata(self, path):
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if path is folder, then it is diffusers
|
||||
if os.path.isdir(path):
|
||||
meta_path = os.path.join(path, 'aitk_meta.yaml')
|
||||
# load it
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = yaml.load(f, Loader=yaml.FullLoader)
|
||||
else:
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
||||
self.step_num = meta['training_info']['step']
|
||||
@@ -731,14 +758,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
Reference in New Issue
Block a user