mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-10 23:49:57 +00:00
Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.
This commit is contained in:
@@ -2,11 +2,13 @@ import copy
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from diffusers import T2IAdapter
|
||||
from safetensors.torch import save_file, load_file
|
||||
# from lycoris.config import PRESET
|
||||
@@ -36,7 +38,8 @@ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
|
||||
parse_metadata_from_safetensors
|
||||
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
|
||||
import gc
|
||||
|
||||
@@ -265,24 +268,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def clean_up_saves(self):
|
||||
# remove old saves
|
||||
# get latest saved step
|
||||
latest_item = None
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors but NOT {job_name}.safetensors
|
||||
pattern = f"{self.job.name}_*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > self.save_config.max_step_saves_to_keep:
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
files.sort(key=os.path.getctime)
|
||||
for file in files[:-self.save_config.max_step_saves_to_keep]:
|
||||
self.print(f"Removing old save: {file}")
|
||||
os.remove(file)
|
||||
# see if a yaml file with same name exists
|
||||
yaml_file = os.path.splitext(file)[0] + ".yaml"
|
||||
if os.path.exists(yaml_file):
|
||||
os.remove(yaml_file)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
# pattern is {job_name}_{zero_filled_step} for both files and directories
|
||||
pattern = f"{self.job.name}_*"
|
||||
items = glob.glob(os.path.join(self.save_root, pattern))
|
||||
# Separate files and directories
|
||||
safetensors_files = [f for f in items if f.endswith('.safetensors')]
|
||||
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
|
||||
# Combine the list and sort by creation time
|
||||
combined_items = safetensors_files + directories
|
||||
combined_items.sort(key=os.path.getctime)
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
for item in items_to_remove:
|
||||
self.print(f"Removing old save: {item}")
|
||||
if os.path.isdir(item):
|
||||
shutil.rmtree(item)
|
||||
else:
|
||||
os.remove(item)
|
||||
# see if a yaml file with same name exists
|
||||
yaml_file = os.path.splitext(item)[0] + ".yaml"
|
||||
if os.path.exists(yaml_file):
|
||||
os.remove(yaml_file)
|
||||
if combined_items:
|
||||
latest_item = combined_items[-1]
|
||||
return latest_item
|
||||
|
||||
def post_save_hook(self, save_path):
|
||||
# override in subclass
|
||||
@@ -366,6 +377,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dtype=get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
else:
|
||||
if self.save_config.save_format == "diffusers":
|
||||
# saving as a folder path
|
||||
file_path = file_path.replace('.safetensors', '')
|
||||
# convert it back to normal object
|
||||
save_meta = parse_metadata_from_safetensors(save_meta)
|
||||
self.sd.save(
|
||||
file_path,
|
||||
save_meta,
|
||||
@@ -415,24 +431,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if name == None:
|
||||
name = self.job.name
|
||||
# get latest saved step
|
||||
latest_path = None
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
|
||||
pattern = f"{name}*{post}.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
# try pt
|
||||
pattern = f"{name}*.pt"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
# Define patterns for both files and directories
|
||||
patterns = [
|
||||
f"{name}*{post}.safetensors",
|
||||
f"{name}*{post}.pt",
|
||||
f"{name}*{post}"
|
||||
]
|
||||
# Search for both files and directories
|
||||
paths = []
|
||||
for pattern in patterns:
|
||||
paths.extend(glob.glob(os.path.join(self.save_root, pattern)))
|
||||
|
||||
# Filter out non-existent paths and sort by creation time
|
||||
if paths:
|
||||
paths = [p for p in paths if os.path.exists(p)]
|
||||
latest_path = max(paths, key=os.path.getctime)
|
||||
|
||||
return latest_path
|
||||
|
||||
def load_training_state_from_metadata(self, path):
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if path is folder, then it is diffusers
|
||||
if os.path.isdir(path):
|
||||
meta_path = os.path.join(path, 'aitk_meta.yaml')
|
||||
# load it
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = yaml.load(f, Loader=yaml.FullLoader)
|
||||
else:
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
||||
self.step_num = meta['training_info']['step']
|
||||
@@ -731,14 +758,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
@@ -50,6 +50,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -60,10 +61,15 @@ find_matches = False
|
||||
|
||||
print(f'Loading diffusers model')
|
||||
|
||||
diffusers_file_path = file_path
|
||||
if args.ssd:
|
||||
diffusers_file_path = "segmind/SSD-1B"
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=file_path,
|
||||
name_or_path=diffusers_file_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
@@ -101,7 +107,7 @@ te_suffix = ''
|
||||
proj_pattern_weight = None
|
||||
proj_pattern_bias = None
|
||||
text_proj_layer = None
|
||||
if args.sdxl:
|
||||
if args.sdxl or args.ssd:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
@@ -114,7 +120,7 @@ if args.sd2:
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2:
|
||||
if args.sdxl or args.sd2 or args.ssd:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
@@ -289,6 +295,8 @@ pbar.close()
|
||||
name = args.name
|
||||
if args.sdxl:
|
||||
name += '_sdxl'
|
||||
elif args.ssd:
|
||||
name += '_ssd'
|
||||
elif args.sd2:
|
||||
name += '_sd2'
|
||||
else:
|
||||
|
||||
@@ -9,12 +9,17 @@ from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
|
||||
SaveFormat = Literal['safetensors', 'diffusers']
|
||||
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.save_every: int = kwargs.get('save_every', 1000)
|
||||
self.dtype: str = kwargs.get('save_dtype', 'float16')
|
||||
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
|
||||
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
|
||||
if self.save_format not in ['safetensors', 'diffusers']:
|
||||
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
|
||||
|
||||
|
||||
class LogingConfig:
|
||||
@@ -187,7 +192,7 @@ class TrainConfig:
|
||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||
|
||||
# match the norm of the noise before computing loss. This will help the model maintain its
|
||||
#current understandin of the brightness of images.
|
||||
# current understandin of the brightness of images.
|
||||
|
||||
self.match_noise_norm = kwargs.get('match_noise_norm', False)
|
||||
|
||||
@@ -229,6 +234,7 @@ class ModelConfig:
|
||||
self.name_or_path: str = kwargs.get('name_or_path', None)
|
||||
self.is_v2: bool = kwargs.get('is_v2', False)
|
||||
self.is_xl: bool = kwargs.get('is_xl', False)
|
||||
self.is_ssd: bool = kwargs.get('is_ssd', False)
|
||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||
self.vae_path = kwargs.get('vae_path', None)
|
||||
@@ -242,6 +248,10 @@ class ModelConfig:
|
||||
if self.name_or_path is None:
|
||||
raise ValueError('name_or_path must be specified')
|
||||
|
||||
if self.is_ssd:
|
||||
# sed sdxl as true since it is mostly the same architecture
|
||||
self.is_xl = True
|
||||
|
||||
|
||||
class ReferenceDatasetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
3419
toolkit/keymaps/stable_diffusion_ssd.json
Normal file
3419
toolkit/keymaps/stable_diffusion_ssd.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors
Normal file
Binary file not shown.
21
toolkit/keymaps/stable_diffusion_ssd_unmatched.json
Normal file
21
toolkit/keymaps/stable_diffusion_ssd_unmatched.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"ldm": {
|
||||
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
|
||||
"shape": [
|
||||
1,
|
||||
77
|
||||
],
|
||||
"min": 0.0,
|
||||
"max": 76.0
|
||||
},
|
||||
"conditioner.embedders.1.model.text_model.embeddings.position_ids": {
|
||||
"shape": [
|
||||
1,
|
||||
77
|
||||
],
|
||||
"min": 0.0,
|
||||
"max": 76.0
|
||||
}
|
||||
},
|
||||
"diffusers": {}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
@@ -111,6 +111,10 @@ def get_ldm_state_dict_from_diffusers(
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||
elif sd_version == 'ssd':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
|
||||
else:
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
|
||||
@@ -129,7 +133,7 @@ def save_ldm_model_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
):
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
sd.state_dict(),
|
||||
@@ -148,13 +152,14 @@ def save_lora_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
):
|
||||
converted_state_dict = OrderedDict()
|
||||
# only handle sxdxl for now
|
||||
if sd_version != 'sdxl':
|
||||
if sd_version != 'sdxl' and sd_version != 'ssd':
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
for key, value in lora_state_dict.items():
|
||||
# todo verify if this works with ssd
|
||||
# test encoders share keys for some reason
|
||||
if key.begins_with('lora_te'):
|
||||
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
|
||||
|
||||
@@ -9,6 +9,7 @@ import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import yaml
|
||||
from PIL import Image
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file, load_file
|
||||
@@ -125,6 +126,7 @@ class StableDiffusion:
|
||||
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
self.is_ssd = model_config.is_ssd
|
||||
|
||||
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||
@@ -157,7 +159,7 @@ class StableDiffusion:
|
||||
if self.model_config.vae_path is not None:
|
||||
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
if self.model_config.is_xl or self.model_config.is_ssd:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
@@ -165,7 +167,7 @@ class StableDiffusion:
|
||||
# pipln = StableDiffusionKDiffusionXLPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(model_path):
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
@@ -176,19 +178,11 @@ class StableDiffusion:
|
||||
**load_args
|
||||
)
|
||||
else:
|
||||
try:
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error loading model from single file. Trying to load from pretrained")
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
flush()
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
@@ -900,17 +894,34 @@ class StableDiffusion:
|
||||
version_string = '2'
|
||||
if self.is_xl:
|
||||
version_string = 'sdxl'
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version=version_string,
|
||||
)
|
||||
if self.config_file is not None:
|
||||
output_path_no_ext = os.path.splitext(output_file)[0]
|
||||
output_config_path = f"{output_path_no_ext}.yaml"
|
||||
shutil.copyfile(self.config_file, output_config_path)
|
||||
if self.is_ssd:
|
||||
# overwrite sdxl because both wil be true here
|
||||
version_string = 'ssd'
|
||||
# if output file does not end in .safetensors, then it is a directory and we are
|
||||
# saving in diffusers format
|
||||
if not output_file.endswith('.safetensors'):
|
||||
# diffusers
|
||||
self.pipeline.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
# save out meta config
|
||||
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
|
||||
with open(meta_path, 'w') as f:
|
||||
yaml.dump(meta, f)
|
||||
|
||||
else:
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version=version_string,
|
||||
)
|
||||
if self.config_file is not None:
|
||||
output_path_no_ext = os.path.splitext(output_file)[0]
|
||||
output_config_path = f"{output_path_no_ext}.yaml"
|
||||
shutil.copyfile(self.config_file, output_config_path)
|
||||
|
||||
def prepare_optimizer_params(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user