Added support for training SSD-1B. Added support for saving models in the diffusers format. We can currently save SSD-1B in safetensors format, but diffusers cannot load it yet.
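For reference, a minimal sketch (not part of this commit, values illustrative) of how the two new knobs are meant to be used together, based on the ModelConfig and SaveConfig changes further down; the import path is an assumption:

# assumed import path for the config classes edited below
from toolkit.config_modules import ModelConfig, SaveConfig

model_config = ModelConfig(
    name_or_path="segmind/SSD-1B",  # hub id also used by the keymap script in this commit
    is_ssd=True,                    # new flag; internally also sets is_xl=True
    dtype="float16",
)
save_config = SaveConfig(
    save_format="diffusers",        # new option; "safetensors" stays the default
    save_dtype="float16",
    max_step_saves_to_keep=5,
)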

Jaret Burkett
2023-11-03 05:01:16 -06:00
parent ceaf1d9454
commit d35733ac06
8 changed files with 3569 additions and 75 deletions

View File

@@ -2,11 +2,13 @@ import copy
import glob
import inspect
import json
import shutil
from collections import OrderedDict
import os
from typing import Union, List
import numpy as np
import yaml
from diffusers import T2IAdapter
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
@@ -36,7 +38,8 @@ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
parse_metadata_from_safetensors
from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
import gc
@@ -265,24 +268,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
def clean_up_saves(self):
# remove old saves
# get latest saved step
latest_item = None
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filled_step}.safetensors but NOT {job_name}.safetensors
pattern = f"{self.job.name}_*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > self.save_config.max_step_saves_to_keep:
# remove all but the latest max_step_saves_to_keep
files.sort(key=os.path.getctime)
for file in files[:-self.save_config.max_step_saves_to_keep]:
self.print(f"Removing old save: {file}")
os.remove(file)
# see if a yaml file with same name exists
yaml_file = os.path.splitext(file)[0] + ".yaml"
if os.path.exists(yaml_file):
os.remove(yaml_file)
return latest_file
else:
return None
# pattern is {job_name}_{zero_filled_step} for both files and directories
pattern = f"{self.job.name}_*"
items = glob.glob(os.path.join(self.save_root, pattern))
# Separate files and directories
safetensors_files = [f for f in items if f.endswith('.safetensors')]
directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
# Combine the lists and sort by creation time
combined_items = safetensors_files + directories
combined_items.sort(key=os.path.getctime)
# remove all but the latest max_step_saves_to_keep
items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
for item in items_to_remove:
self.print(f"Removing old save: {item}")
if os.path.isdir(item):
shutil.rmtree(item)
else:
os.remove(item)
# see if a yaml file with the same name exists
yaml_file = os.path.splitext(item)[0] + ".yaml"
if os.path.exists(yaml_file):
os.remove(yaml_file)
if combined_items:
latest_item = combined_items[-1]
return latest_item
def post_save_hook(self, save_path):
# override in subclass
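A standalone sketch (hypothetical names, assuming max_step_saves_to_keep >= 1) of the pruning rule above: diffusers step-save folders and .safetensors files now compete for the same max_step_saves_to_keep slots and are pruned oldest-first by creation time.

import glob
import os
import shutil

def prune_saves(save_root: str, job_name: str, keep: int) -> None:
    # gather both {job_name}_{step}.safetensors files and {job_name}_{step}/ folders
    items = glob.glob(os.path.join(save_root, f"{job_name}_*"))
    items = [p for p in items if p.endswith('.safetensors') or os.path.isdir(p)]
    items.sort(key=os.path.getctime)  # oldest first
    for item in items[:-keep]:        # keep only the newest `keep` saves
        if os.path.isdir(item):
            shutil.rmtree(item)
        else:
            os.remove(item)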
@@ -366,6 +377,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
if self.save_config.save_format == "diffusers":
# saving as a folder path
file_path = file_path.replace('.safetensors', '')
# convert it back to a normal object
save_meta = parse_metadata_from_safetensors(save_meta)
self.sd.save(
file_path,
save_meta,
@@ -415,24 +431,35 @@ class BaseSDTrainProcess(BaseTrainProcess):
if name == None:
name = self.job.name
# get latest saved step
latest_path = None
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filled_step}.safetensors or {job_name}.safetensors
pattern = f"{name}*{post}.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
# try pt
pattern = f"{name}*.pt"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
return latest_file
else:
return None
# Define patterns for both files and directories
patterns = [
f"{name}*{post}.safetensors",
f"{name}*{post}.pt",
f"{name}*{post}"
]
# Search for both files and directories
paths = []
for pattern in patterns:
paths.extend(glob.glob(os.path.join(self.save_root, pattern)))
# Filter out non-existent paths and sort by creation time
if paths:
paths = [p for p in paths if os.path.exists(p)]
latest_path = max(paths, key=os.path.getctime)
return latest_path
def load_training_state_from_metadata(self, path):
meta = load_metadata_from_safetensors(path)
# if path is a folder, then it is in diffusers format
if os.path.isdir(path):
meta_path = os.path.join(path, 'aitk_meta.yaml')
# load it
with open(meta_path, 'r') as f:
meta = yaml.load(f, Loader=yaml.FullLoader)
else:
meta = load_metadata_from_safetensors(path)
# if 'training_info' in OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
self.step_num = meta['training_info']['step']
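As a hedged illustration of what the new branch reads, an aitk_meta.yaml written by the diffusers save path (see the save() change in toolkit/stable_diffusion_model further down) only needs to carry these keys; the values here are made up:

# hypothetical contents of aitk_meta.yaml, reduced to the keys the resume logic reads
meta = {
    'training_info': {
        'step': 1500,  # restored into self.step_num
        'epoch': 3,    # restored into self.epoch_num
    },
}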
@@ -731,14 +758,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in OrderedDict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)

View File

@@ -50,6 +50,7 @@ parser.add_argument(
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
@@ -60,10 +61,15 @@ find_matches = False
print(f'Loading diffusers model')
diffusers_file_path = file_path
if args.ssd:
diffusers_file_path = "segmind/SSD-1B"
diffusers_model_config = ModelConfig(
name_or_path=file_path,
name_or_path=diffusers_file_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
@@ -101,7 +107,7 @@ te_suffix = ''
proj_pattern_weight = None
proj_pattern_bias = None
text_proj_layer = None
if args.sdxl:
if args.sdxl or args.ssd:
te_suffix = '1'
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
@@ -114,7 +120,7 @@ if args.sd2:
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "cond_stage_model.model.text_projection"
if args.sdxl or args.sd2:
if args.sdxl or args.sd2 or args.ssd:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
@@ -289,6 +295,8 @@ pbar.close()
name = args.name
if args.sdxl:
name += '_sdxl'
elif args.ssd:
name += '_ssd'
elif args.sd2:
name += '_sd2'
else:
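A hedged usage sketch for the new flag; the script filename and the positional checkpoint argument are assumptions, only the --ssd/--name flags appear in this diff:

# python tools/generate_keymap.py /path/to/ssd-1b-ldm.safetensors --ssd --name stable_diffusion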

View File

@@ -9,12 +9,17 @@ from toolkit.prompt_utils import PromptEmbeds
ImgExt = Literal['jpg', 'png', 'webp']
SaveFormat = Literal['safetensors', 'diffusers']
class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
self.dtype: str = kwargs.get('save_dtype', 'float16')
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
if self.save_format not in ['safetensors', 'diffusers']:
raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
class LogingConfig:
@@ -187,7 +192,7 @@ class TrainConfig:
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
# match the norm of the noise before computing loss. This will help the model maintain its
#current understandin of the brightness of images.
# current understanding of the brightness of images.
self.match_noise_norm = kwargs.get('match_noise_norm', False)
@@ -229,6 +234,7 @@ class ModelConfig:
self.name_or_path: str = kwargs.get('name_or_path', None)
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_ssd: bool = kwargs.get('is_ssd', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
self.vae_path = kwargs.get('vae_path', None)
@@ -242,6 +248,10 @@ class ModelConfig:
if self.name_or_path is None:
raise ValueError('name_or_path must be specified')
if self.is_ssd:
# set is_xl to True since it is mostly the same architecture
self.is_xl = True
class ReferenceDatasetConfig:
def __init__(self, **kwargs):
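A quick illustration of the coupling added above (a sketch, using only the kwargs shown in this diff):

cfg = ModelConfig(name_or_path="segmind/SSD-1B", is_ssd=True)
assert cfg.is_xl  # is_ssd routes the model through the SDXL code paths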

File diff suppressed because it is too large

View File

@@ -0,0 +1,21 @@
{
"ldm": {
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
"shape": [
1,
77
],
"min": 0.0,
"max": 76.0
},
"conditioner.embedders.1.model.text_model.embeddings.position_ids": {
"shape": [
1,
77
],
"min": 0.0,
"max": 76.0
}
},
"diffusers": {}
}

View File

@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
def get_ldm_state_dict_from_diffusers(
state_dict: 'OrderedDict',
sd_version: Literal['1', '2', 'sdxl'] = '2',
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
device='cpu',
dtype=get_torch_dtype('fp32'),
):
@@ -111,6 +111,10 @@ def get_ldm_state_dict_from_diffusers(
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
elif sd_version == 'ssd':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
else:
raise ValueError(f"Invalid sd_version {sd_version}")
@@ -129,7 +133,7 @@ def save_ldm_model_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
):
converted_state_dict = get_ldm_state_dict_from_diffusers(
sd.state_dict(),
@@ -148,13 +152,14 @@ def save_lora_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
):
converted_state_dict = OrderedDict()
# only handle sdxl and ssd for now
if sd_version != 'sdxl':
if sd_version != 'sdxl' and sd_version != 'ssd':
raise ValueError(f"Invalid sd_version {sd_version}")
for key, value in lora_state_dict.items():
# todo verify if this works with ssd
# text encoders share keys for some reason
if key.startswith('lora_te'):
converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)

View File

@@ -9,6 +9,7 @@ import sys
import os
from collections import OrderedDict
import yaml
from PIL import Image
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
@@ -125,6 +126,7 @@ class StableDiffusion:
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.is_ssd = model_config.is_ssd
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -157,7 +159,7 @@ class StableDiffusion:
if self.model_config.vae_path is not None:
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
if self.model_config.is_xl:
if self.model_config.is_xl or self.model_config.is_ssd:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
else:
@@ -165,7 +167,7 @@ class StableDiffusion:
# pipln = StableDiffusionKDiffusionXLPipeline
# see if path exists
if not os.path.exists(model_path):
if not os.path.exists(model_path) or os.path.isdir(model_path):
# try to load with default diffusers
pipe = pipln.from_pretrained(
model_path,
@@ -176,19 +178,11 @@ class StableDiffusion:
**load_args
)
else:
try:
pipe = pipln.from_single_file(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
except Exception as e:
print("Error loading model from single file. Trying to load from pretrained")
pipe = pipln.from_pretrained(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
pipe = pipln.from_single_file(
model_path,
device=self.device_torch,
torch_dtype=self.torch_dtype,
)
flush()
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
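An illustration of how the amended check routes different model_path values (paths are hypothetical):

# "segmind/SSD-1B"                      -> not on disk             -> from_pretrained (hub download)
# "output/my_job_000000500"             -> exists, is a directory  -> from_pretrained (diffusers folder save)
# "output/my_job_000000500.safetensors" -> exists, is a file       -> from_single_file (the pretrained fallback was removed)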
@@ -900,17 +894,34 @@ class StableDiffusion:
version_string = '2'
if self.is_xl:
version_string = 'sdxl'
save_ldm_model_from_diffusers(
sd=self,
output_file=output_file,
meta=meta,
save_dtype=save_dtype,
sd_version=version_string,
)
if self.config_file is not None:
output_path_no_ext = os.path.splitext(output_file)[0]
output_config_path = f"{output_path_no_ext}.yaml"
shutil.copyfile(self.config_file, output_config_path)
if self.is_ssd:
# overwrite sdxl because both will be true here
version_string = 'ssd'
# if output file does not end in .safetensors, then it is a directory and we are
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
else:
save_ldm_model_from_diffusers(
sd=self,
output_file=output_file,
meta=meta,
save_dtype=save_dtype,
sd_version=version_string,
)
if self.config_file is not None:
output_path_no_ext = os.path.splitext(output_file)[0]
output_config_path = f"{output_path_no_ext}.yaml"
shutil.copyfile(self.config_file, output_config_path)
def prepare_optimizer_params(
self,
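A hedged usage sketch of the updated save(): whether the output path ends in .safetensors selects the format (argument list abbreviated; the dtype argument passed by the trainer above is assumed to be optional):

sd.save("output/my_job_000000500", meta)              # folder path -> pipeline.save_pretrained + aitk_meta.yaml
sd.save("output/my_job_000000500.safetensors", meta)  # file path   -> single-file LDM checkpoint; sd_version derived from is_xl/is_ssd flags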