mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Added support for training SSD-1B. Added support for saving models in diffusers format. We can currently save SSD-1B in safetensors format, but diffusers cannot load it yet.
This commit is contained in:
@@ -50,6 +50,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -60,10 +61,15 @@ find_matches = False
|
||||
|
||||
print(f'Loading diffusers model')
|
||||
|
||||
diffusers_file_path = file_path
|
||||
if args.ssd:
|
||||
diffusers_file_path = "segmind/SSD-1B"
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=file_path,
|
||||
name_or_path=diffusers_file_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
@@ -101,7 +107,7 @@ te_suffix = ''
|
||||
proj_pattern_weight = None
|
||||
proj_pattern_bias = None
|
||||
text_proj_layer = None
|
||||
if args.sdxl:
|
||||
if args.sdxl or args.ssd:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
@@ -114,7 +120,7 @@ if args.sd2:
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2:
|
||||
if args.sdxl or args.sd2 or args.ssd:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
@@ -289,6 +295,8 @@ pbar.close()
|
||||
name = args.name
|
||||
if args.sdxl:
|
||||
name += '_sdxl'
|
||||
elif args.ssd:
|
||||
name += '_ssd'
|
||||
elif args.sd2:
|
||||
name += '_sd2'
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user