mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Added support for training ssd-1B. Added support for saving models into diffusers format. We can currently save in safetensors format for ssd-1b, but diffusers cannot load it yet.
This commit is contained in:
@@ -9,6 +9,7 @@ import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import yaml
|
||||
from PIL import Image
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file, load_file
|
||||
@@ -125,6 +126,7 @@ class StableDiffusion:
|
||||
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
self.is_ssd = model_config.is_ssd
|
||||
|
||||
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||
@@ -157,7 +159,7 @@ class StableDiffusion:
|
||||
if self.model_config.vae_path is not None:
|
||||
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
||||
|
||||
if self.model_config.is_xl:
|
||||
if self.model_config.is_xl or self.model_config.is_ssd:
|
||||
if self.custom_pipeline is not None:
|
||||
pipln = self.custom_pipeline
|
||||
else:
|
||||
@@ -165,7 +167,7 @@ class StableDiffusion:
|
||||
# pipln = StableDiffusionKDiffusionXLPipeline
|
||||
|
||||
# see if path exists
|
||||
if not os.path.exists(model_path):
|
||||
if not os.path.exists(model_path) or os.path.isdir(model_path):
|
||||
# try to load with default diffusers
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
@@ -176,19 +178,11 @@ class StableDiffusion:
|
||||
**load_args
|
||||
)
|
||||
else:
|
||||
try:
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error loading model from single file. Trying to load from pretrained")
|
||||
pipe = pipln.from_pretrained(
|
||||
model_path,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
pipe = pipln.from_single_file(
|
||||
model_path,
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
flush()
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
@@ -900,17 +894,34 @@ class StableDiffusion:
|
||||
version_string = '2'
|
||||
if self.is_xl:
|
||||
version_string = 'sdxl'
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version=version_string,
|
||||
)
|
||||
if self.config_file is not None:
|
||||
output_path_no_ext = os.path.splitext(output_file)[0]
|
||||
output_config_path = f"{output_path_no_ext}.yaml"
|
||||
shutil.copyfile(self.config_file, output_config_path)
|
||||
if self.is_ssd:
|
||||
# overwrite sdxl because both wil be true here
|
||||
version_string = 'ssd'
|
||||
# if output file does not end in .safetensors, then it is a directory and we are
|
||||
# saving in diffusers format
|
||||
if not output_file.endswith('.safetensors'):
|
||||
# diffusers
|
||||
self.pipeline.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
# save out meta config
|
||||
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
|
||||
with open(meta_path, 'w') as f:
|
||||
yaml.dump(meta, f)
|
||||
|
||||
else:
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version=version_string,
|
||||
)
|
||||
if self.config_file is not None:
|
||||
output_path_no_ext = os.path.splitext(output_file)[0]
|
||||
output_config_path = f"{output_path_no_ext}.yaml"
|
||||
shutil.copyfile(self.config_file, output_config_path)
|
||||
|
||||
def prepare_optimizer_params(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user