mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes
This commit is contained in:
@@ -54,6 +54,7 @@ parser.add_argument('--name', type=str, default='stable_diffusion', help='name f
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--refiner', action='store_true', help='is refiner model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--vega', action='store_true', help='is vega model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -66,15 +67,15 @@ print(f'Loading diffusers model')
|
||||
|
||||
ignore_ldm_begins_with = []
|
||||
|
||||
diffusers_file_path = file_path
|
||||
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
|
||||
if args.ssd:
|
||||
diffusers_file_path = "segmind/SSD-1B"
|
||||
if args.vega:
|
||||
diffusers_file_path = "segmind/Segmind-Vega"
|
||||
|
||||
# if args.refiner:
|
||||
# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
||||
|
||||
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
|
||||
|
||||
if not args.refiner:
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
@@ -82,6 +83,7 @@ if not args.refiner:
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
is_vega=args.vega,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
@@ -157,7 +159,7 @@ te_suffix = ''
|
||||
proj_pattern_weight = None
|
||||
proj_pattern_bias = None
|
||||
text_proj_layer = None
|
||||
if args.sdxl or args.ssd:
|
||||
if args.sdxl or args.ssd or args.vega:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
@@ -176,10 +178,13 @@ if args.sd2:
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
if args.sdxl or args.sd2 or args.ssd or args.refiner or args.vega:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
elif "conditioner.embedders.1.model.text_projection.weight" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection.weight"].shape[0])
|
||||
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
|
||||
@@ -191,6 +196,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
try:
|
||||
match = re.match(proj_pattern_weight, ldm_key)
|
||||
if match:
|
||||
if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight":
|
||||
print("here")
|
||||
number = int(match.group(1))
|
||||
new_val = torch.cat([
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
|
||||
@@ -217,6 +224,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
],
|
||||
}
|
||||
|
||||
matched_ldm_keys.append(ldm_key)
|
||||
|
||||
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
|
||||
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
|
||||
@@ -266,6 +275,8 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
],
|
||||
}
|
||||
|
||||
matched_ldm_keys.append(ldm_key)
|
||||
|
||||
# add diffusers operators
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = {
|
||||
"slice": [
|
||||
@@ -298,6 +309,9 @@ for ldm_key in ldm_dict_keys:
|
||||
ldm_shape_tuple = ldm_state_dict[ldm_key].shape
|
||||
ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple)
|
||||
for diffusers_key in diffusers_dict_keys:
|
||||
if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight" and diffusers_key == "te1_text_model.encoder.layers.0.self_attn.q_proj.weight":
|
||||
print("here")
|
||||
|
||||
diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape
|
||||
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
|
||||
|
||||
@@ -356,6 +370,8 @@ if args.sdxl:
|
||||
name += '_sdxl'
|
||||
elif args.ssd:
|
||||
name += '_ssd'
|
||||
elif args.vega:
|
||||
name += '_vega'
|
||||
elif args.refiner:
|
||||
name += '_refiner'
|
||||
elif args.sd2:
|
||||
|
||||
Reference in New Issue
Block a user