Use PEFT format for Flux LoRAs so they are compatible with diffusers. Allow loading an assistant LoRA

Jaret Burkett
2024-08-05 14:34:37 -06:00
parent edb7e827ee
commit 187663ab55
4 changed files with 87 additions and 6 deletions


@@ -379,6 +379,8 @@ class ModelConfig:
self._original_refiner_name_or_path = self.refiner_name_or_path
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
self.lora_path = kwargs.get('lora_path', None)
# mainly for decompression loras for distilled models
self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
self.latent_space_version = kwargs.get('latent_space_version', None)
# only for SDXL models for now

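For context, a minimal sketch of how the new option might be set. ModelConfig reads it via kwargs.get as shown above; the import path, the surrounding construction, and the example path are assumptions, not part of this commit.

# Hypothetical usage of the new option; the path below is a placeholder.
from toolkit.config_modules import ModelConfig  # module path is an assumption

model_config = ModelConfig(
    # decompression/assistant lora for a distilled model; cannot be combined with lora_path
    assistant_lora_path="/path/to/assistant_lora.safetensors",
)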

@@ -124,6 +124,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
PEFT_PREFIX_UNET = "unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
# SDXL: must start with LORA_PREFIX_TEXT_ENCODER
@@ -171,6 +172,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
network_type: str = "lora",
full_train_in_out: bool = False,
transformer_only: bool = False,
peft_format: bool = False,
**kwargs
) -> None:
"""
@@ -223,6 +225,17 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.module_class = DoRAModule
module_class = DoRAModule
self.peft_format = peft_format
# for now, always use the peft format for flux
if self.is_flux:
self.peft_format = True
if self.peft_format:
# no alpha for peft
self.alpha = self.lora_dim
self.conv_alpha = self.conv_lora_dim
self.full_train_in_out = full_train_in_out
if modules_dim is not None:
@@ -252,8 +265,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
unet_prefix = self.LORA_PREFIX_UNET
if self.peft_format:
unet_prefix = self.PEFT_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow or is_flux:
unet_prefix = f"lora_transformer"
if self.peft_format:
unet_prefix = "transformer"
prefix = (
unet_prefix
@@ -282,7 +299,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
lora_name = ".".join(lora_name)
# if it doesn't have a name, it will have two dots
lora_name = lora_name.replace("..", ".")
lora_name = lora_name.replace(".", "_")
if self.peft_format:
# we replace this on saving
lora_name = lora_name.replace(".", "$$")
else:
lora_name = lora_name.replace(".", "_")
skip = False
if any([word in child_name for word in self.ignore_if_contains]):

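To make the naming scheme above concrete: with peft_format enabled the prefix becomes "transformer" and the dots in the module path are temporarily replaced with a "$$" placeholder, which is swapped back to "." when the weights are saved. A minimal standalone sketch (the helper function and the example module path are illustrative, not the toolkit's API):

def make_lora_name(module_path: str, peft_format: bool) -> str:
    # peft/diffusers keys keep the real module path, e.g. "transformer.transformer_blocks.0.attn.to_q"
    prefix = "transformer" if peft_format else "lora_transformer"
    name = f"{prefix}.{module_path}"
    if peft_format:
        # keep dots recoverable inside the network; "$$" becomes "." again on save
        return name.replace(".", "$$")
    return name.replace(".", "_")

print(make_lora_name("transformer_blocks.0.attn.to_q", peft_format=True))
# -> transformer$$transformer_blocks$$0$$attn$$to_q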

@@ -204,7 +204,6 @@ class ToolkitModuleMixin:
return lx * scale
def lorm_forward(self: Network, x, *args, **kwargs):
network: Network = self.network_ref()
if not network.is_active:
@@ -492,6 +491,24 @@ class ToolkitNetworkMixin:
v = v.detach().clone().to("cpu").to(dtype)
save_dict[key] = v
if self.peft_format:
# lora_down = lora_A
# lora_up = lora_B
# no alpha
new_save_dict = {}
for key, value in save_dict.items():
if key.endswith('.alpha'):
continue
new_key = key
new_key = new_key.replace('lora_down', 'lora_A')
new_key = new_key.replace('lora_up', 'lora_B')
# replace all $$ with .
new_key = new_key.replace('$$', '.')
new_save_dict[new_key] = value
save_dict = new_save_dict
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(state_dict, metadata)
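The save-time conversion above, summarized as a standalone helper. This is a sketch only (the helper name is illustrative); the real method also handles dtype casting and metadata.

def to_peft_keys(save_dict: dict) -> dict:
    # convert toolkit-style keys to the peft/diffusers format on save
    new_save_dict = {}
    for key, value in save_dict.items():
        if key.endswith(".alpha"):
            continue  # peft format stores no alpha tensors
        new_key = key.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
        new_key = new_key.replace("$$", ".")  # restore the real module path
        new_save_dict[new_key] = value
    return new_save_dict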
@@ -519,6 +536,20 @@ class ToolkitNetworkMixin:
# replace old double __ with single _
if self.is_pixart:
load_key = load_key.replace('__', '_')
if self.peft_format:
# lora_down = lora_A
# lora_up = lora_B
# no alpha
if load_key.endswith('.alpha'):
continue
load_key = load_key.replace('lora_A', 'lora_down')
load_key = load_key.replace('lora_B', 'lora_up')
# replace all . with $$
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
load_sd[load_key] = value
# extract extra items from state dict
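And the inverse mapping applied on load, again as an illustrative helper; the key rewrites mirror the diff, the function wrapper is an assumption.

def from_peft_keys(state_dict: dict) -> dict:
    # convert peft/diffusers keys back to the toolkit's internal naming on load
    load_sd = {}
    for load_key, value in state_dict.items():
        if load_key.endswith(".alpha"):
            continue  # peft-format files carry no alpha
        load_key = load_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up")
        load_key = load_key.replace(".", "$$")  # re-apply the internal placeholder
        load_key = load_key.replace("$$lora_down$$", ".lora_down.")
        load_key = load_key.replace("$$lora_up$$", ".lora_up.")
        load_sd[load_key] = value
    return load_sd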
@@ -533,7 +564,8 @@ class ToolkitNetworkMixin:
del load_sd[key]
print(f"Missing keys: {to_delete}")
if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
len(to_delete) == 1 and 'emb_params' in to_delete):
print(" Attempting to load with forced keymap")
return self.load_weights(file, force_weight_mapping=True)
@@ -657,4 +689,3 @@ class ToolkitNetworkMixin:
params_reduced += (num_orig_module_params - num_lorem_params)
return params_reduced


@@ -616,7 +616,17 @@ class StableDiffusion:
if self.model_config.lora_path is not None:
pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
pipe.fuse_lora()
self.unet.fuse_lora()
# unfortunately, there is no easier way with peft
pipe.unload_lora_weights()
if self.model_config.assistant_lora_path is not None:
if self.model_config.lora_path is not None:
raise ValueError("Cannot have both lora and assistant lora")
print("Loading assistant lora")
pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
pipe.fuse_lora(lora_scale=1.0)
# unfortunately, there is no easier way with peft
pipe.unload_lora_weights()
self.tokenizer = tokenizer
self.text_encoder = text_encoder
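Why fuse-then-unload works: fuse_lora folds the LoRA delta into the base weights in place, so the peft adapter can be dropped afterwards and the same delta can later be removed by fusing again with a negative scale. A toy weight-level illustration (not diffusers code; the alpha/rank scaling is omitted, and with the peft format above alpha equals the rank so that factor is 1 anyway):

import torch

W = torch.randn(8, 8)   # base weight
A = torch.randn(4, 8)   # lora_A (down projection)
B = torch.randn(8, 4)   # lora_B (up projection)

W_fused = W + 1.0 * (B @ A)              # fuse_lora(lora_scale=1.0)
W_restored = W_fused + (-1.0) * (B @ A)  # "unfuse" via fuse_lora(lora_scale=-1.0)
assert torch.allclose(W, W_restored, atol=1e-5)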
@@ -690,7 +700,15 @@ class StableDiffusion:
pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
):
merge_multiplier = 1.0
# sample_folder = os.path.join(self.save_root, 'samples')
# if using assistant, unfuse it
if self.model_config.assistant_lora_path is not None:
print("Unloading asistant lora")
# unfortunately, not an easier way with peft
self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
self.pipeline.fuse_lora(lora_scale=-1.0)
self.pipeline.unload_lora_weights()
if self.network is not None:
self.network.eval()
network = self.network
@@ -1162,6 +1180,14 @@ class StableDiffusion:
network.merge_out(merge_multiplier)
# self.tokenizer.to(original_device_dict['tokenizer'])
# re-fuse loras
if self.model_config.assistant_lora_path is not None:
print("Loading asistant lora")
# unfortunately, not an easier way with peft
self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
self.pipeline.fuse_lora(lora_scale=1.0)
self.pipeline.unload_lora_weights()
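Putting the sampling-time handling together: the assistant lora is subtracted before sampling (negative-scale fuse) and added back afterwards (positive-scale fuse). A condensed sketch of that pattern; the function wrapper and try/finally are illustrative, while the diffusers calls are the same ones used above.

def sample_without_assistant(pipeline, assistant_lora_path, run_sampling):
    # remove the fused assistant lora so samples run against the base weights without it
    pipeline.load_lora_weights(assistant_lora_path, adapter_name="assistant_lora")
    pipeline.fuse_lora(lora_scale=-1.0)
    pipeline.unload_lora_weights()
    try:
        run_sampling()
    finally:
        # re-fuse the assistant lora for continued training
        pipeline.load_lora_weights(assistant_lora_path, adapter_name="assistant_lora")
        pipeline.fuse_lora(lora_scale=1.0)
        pipeline.unload_lora_weights()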
def get_latent_noise(
self,
height=None,