mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes
This commit is contained in:
@@ -97,7 +97,7 @@ def convert_state_dict_to_ldm_with_mapping(
|
||||
|
||||
def get_ldm_state_dict_from_diffusers(
|
||||
state_dict: 'OrderedDict',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'sdxl_refiner'] = '2',
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega', 'sdxl_refiner'] = '2',
|
||||
device='cpu',
|
||||
dtype=get_torch_dtype('fp32'),
|
||||
):
|
||||
@@ -115,6 +115,10 @@ def get_ldm_state_dict_from_diffusers(
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
|
||||
elif sd_version == 'vega':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega.json')
|
||||
elif sd_version == 'sdxl_refiner':
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
|
||||
@@ -137,7 +141,7 @@ def save_ldm_model_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
|
||||
):
|
||||
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||
sd.state_dict(),
|
||||
@@ -156,11 +160,11 @@ def save_lora_from_diffusers(
|
||||
output_file: str,
|
||||
meta: 'OrderedDict',
|
||||
save_dtype=get_torch_dtype('fp16'),
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd'] = '2'
|
||||
sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
|
||||
):
|
||||
converted_state_dict = OrderedDict()
|
||||
# only handle sxdxl for now
|
||||
if sd_version != 'sdxl' and sd_version != 'ssd':
|
||||
if sd_version != 'sdxl' and sd_version != 'ssd' and sd_version != 'vega':
|
||||
raise ValueError(f"Invalid sd_version {sd_version}")
|
||||
for key, value in lora_state_dict.items():
|
||||
# todo verify if this works with ssd
|
||||
|
||||
Reference in New Issue
Block a user