More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes

This commit is contained in:
Jaret Burkett
2023-12-15 06:02:10 -07:00
parent e5177833b2
commit 39870411d8
14 changed files with 3501 additions and 106 deletions

View File

@@ -54,6 +54,7 @@ parser.add_argument('--name', type=str, default='stable_diffusion', help='name f
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--vega', action='store_true', help='is vega model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
@@ -66,15 +67,15 @@ print(f'Loading diffusers model')
ignore_ldm_begins_with = []
diffusers_file_path = file_path
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
if args.ssd:
diffusers_file_path = "segmind/SSD-1B"
if args.vega:
diffusers_file_path = "segmind/Segmind-Vega"
# if args.refiner:
# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
if not args.refiner:
diffusers_model_config = ModelConfig(
@@ -82,6 +83,7 @@ if not args.refiner:
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
is_vega=args.vega,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
@@ -157,7 +159,7 @@ te_suffix = ''
proj_pattern_weight = None
proj_pattern_bias = None
text_proj_layer = None
if args.sdxl or args.ssd:
if args.sdxl or args.ssd or args.vega:
te_suffix = '1'
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
@@ -176,10 +178,13 @@ if args.sd2:
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "cond_stage_model.model.text_projection"
if args.sdxl or args.sd2 or args.ssd or args.refiner:
if args.sdxl or args.sd2 or args.ssd or args.refiner or args.vega:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
elif "conditioner.embedders.1.model.text_projection.weight" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection.weight"].shape[0])
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
@@ -191,6 +196,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
try:
match = re.match(proj_pattern_weight, ldm_key)
if match:
if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight":
print("here")
number = int(match.group(1))
new_val = torch.cat([
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
@@ -217,6 +224,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
],
}
matched_ldm_keys.append(ldm_key)
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
@@ -266,6 +275,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
],
}
matched_ldm_keys.append(ldm_key)
# add diffusers operators
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = {
"slice": [
@@ -298,6 +309,9 @@ for ldm_key in ldm_dict_keys:
ldm_shape_tuple = ldm_state_dict[ldm_key].shape
ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple)
for diffusers_key in diffusers_dict_keys:
if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight" and diffusers_key == "te1_text_model.encoder.layers.0.self_attn.q_proj.weight":
print("here")
diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
@@ -356,6 +370,8 @@ if args.sdxl:
name += '_sdxl'
elif args.ssd:
name += '_ssd'
elif args.vega:
name += '_vega'
elif args.refiner:
name += '_refiner'
elif args.sd2: