Added support for training SSD-1B. Added support for saving models in diffusers format. SSD-1B models can currently be saved in safetensors format, but diffusers cannot load that format yet.

This commit is contained in:
Jaret Burkett
2023-11-03 05:01:16 -06:00
parent ceaf1d9454
commit d35733ac06
8 changed files with 3569 additions and 75 deletions

View File

@@ -50,6 +50,7 @@ parser.add_argument(
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
@@ -60,10 +61,15 @@ find_matches = False
print(f'Loading diffusers model')
diffusers_file_path = file_path
if args.ssd:
diffusers_file_path = "segmind/SSD-1B"
diffusers_model_config = ModelConfig(
name_or_path=file_path,
name_or_path=diffusers_file_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
@@ -101,7 +107,7 @@ te_suffix = ''
proj_pattern_weight = None
proj_pattern_bias = None
text_proj_layer = None
if args.sdxl:
if args.sdxl or args.ssd:
te_suffix = '1'
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
@@ -114,7 +120,7 @@ if args.sd2:
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "cond_stage_model.model.text_projection"
if args.sdxl or args.sd2:
if args.sdxl or args.sd2 or args.ssd:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
@@ -289,6 +295,8 @@ pbar.close()
name = args.name
if args.sdxl:
name += '_sdxl'
elif args.ssd:
name += '_ssd'
elif args.sd2:
name += '_sd2'
else: