Added refiner fine tuning. Works, but needs some polish.

This commit is contained in:
Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions

View File

@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
def get_ldm_state_dict_from_diffusers(
state_dict: 'OrderedDict',
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
device='cpu',
dtype=get_torch_dtype('fp32'),
):
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
elif sd_version == 'sdxl_refiner':
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
else:
raise ValueError(f"Invalid sd_version {sd_version}")