mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
|
||||
elif sd_version == 'sdxl_refiner':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
|
||||
else:
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user