More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes

This commit is contained in:
Jaret Burkett
2023-12-15 06:02:10 -07:00
parent e5177833b2
commit 39870411d8
14 changed files with 3501 additions and 106 deletions

View File

@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
def get_ldm_state_dict_from_diffusers(
state_dict: 'OrderedDict',
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega', 'sdxl_refiner'] = '2',
device='cpu',
dtype=get_torch_dtype('fp32'),
):
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
elif sd_version == 'vega':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega.json')
elif sd_version == 'sdxl_refiner':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
@@ -137,7 +141,7 @@ def save_ldm_model_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
):
converted_state_dict = get_ldm_state_dict_from_diffusers(
sd.state_dict(),
@@ -156,11 +160,11 @@ def save_lora_from_diffusers(
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
):
converted_state_dict = OrderedDict()
# only handle sxdxl for now
if sd_version != 'sdxl' and sd_version != 'ssd':
if sd_version != 'sdxl' and sd_version != 'ssd' and sd_version != 'vega':
raise ValueError(f"Invalid sd_version {sd_version}")
for key, value in lora_state_dict.items():
# todo verify if this works with ssd