Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Use peft format for flux loras so they are compatible with diffusers. Allow loading an assistant lora.
@@ -379,6 +379,8 @@ class ModelConfig:
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
         self.lora_path = kwargs.get('lora_path', None)
+        # mainly for decompression loras for distilled models
+        self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
         self.latent_space_version = kwargs.get('latent_space_version', None)

         # only for SDXL models for now
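
For reference, a minimal sketch of how the new option might be passed; the import path and every value here are illustrative, not taken from the commit:

    from toolkit.config_modules import ModelConfig  # import path assumed

    model_config = ModelConfig(
        name_or_path="black-forest-labs/FLUX.1-schnell",  # example base model
        # a decompression lora for a distilled model, fused in only for sampling
        assistant_lora_path="path/to/assistant_lora.safetensors",  # example path
    )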
@@ -124,6 +124,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
+    PEFT_PREFIX_UNET = "unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"

     # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
@@ -171,6 +172,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             network_type: str = "lora",
             full_train_in_out: bool = False,
             transformer_only: bool = False,
+            peft_format: bool = False,
             **kwargs
     ) -> None:
         """
@@ -223,6 +225,17 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             self.module_class = DoRAModule
             module_class = DoRAModule

+        self.peft_format = peft_format
+
+        # always do peft for flux only for now
+        if self.is_flux:
+            self.peft_format = True
+
+        if self.peft_format:
+            # no alpha for peft
+            self.alpha = self.lora_dim
+            self.conv_alpha = self.conv_lora_dim
+
         self.full_train_in_out = full_train_in_out

         if modules_dim is not None:
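
A note on why alpha is pinned to the rank when peft_format is set: kohya-style LoRA scales the update by alpha / rank, while PEFT checkpoints carry no alpha and diffusers assumes a scale of 1.0. A small sketch of the arithmetic:

    # kohya-style effective scale for a LoRA update
    rank = 16             # lora_dim
    alpha = rank          # what the code above sets when peft_format is True
    scale = alpha / rank  # == 1.0, matching PEFT's implicit scaling
    assert scale == 1.0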
@@ -252,8 +265,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             target_replace_modules: List[torch.nn.Module],
     ) -> List[LoRAModule]:
         unet_prefix = self.LORA_PREFIX_UNET
+        if self.peft_format:
+            unet_prefix = self.PEFT_PREFIX_UNET
         if is_pixart or is_v3 or is_auraflow or is_flux:
             unet_prefix = f"lora_transformer"
+            if self.peft_format:
+                unet_prefix = "transformer"

         prefix = (
             unet_prefix
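
To illustrate what the prefix choice changes, here is a made-up flux module path rendered under both conventions (the exact block names are illustrative):

    # kohya-style: prefix "lora_transformer", dots flattened to underscores
    kohya_key = "lora_transformer_transformer_blocks_0_attn_to_q.lora_down.weight"
    # peft/diffusers-style: prefix "transformer", dotted module path preserved
    peft_key = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"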
@@ -282,7 +299,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 lora_name = ".".join(lora_name)
                 # if it doesn't have a name, it will have two dots
                 lora_name = lora_name.replace("..", ".")
-                lora_name = lora_name.replace(".", "_")
+                if self.peft_format:
+                    # we replace this on saving
+                    lora_name = lora_name.replace(".", "$$")
+                else:
+                    lora_name = lora_name.replace(".", "_")

                 skip = False
                 if any([word in child_name for word in self.ignore_if_contains]):
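
The "$$" placeholder exists because internal lora names must not contain dots (dots are reserved for the lora_down/lora_up suffixes), yet the PEFT format needs the original dotted module path back at save time. A round-trip sketch:

    module_path = "transformer.transformer_blocks.0.attn.to_q"  # example path
    lora_name = module_path.replace(".", "$$")  # dot-free internal name
    restored = lora_name.replace("$$", ".")     # recovered exactly on save
    assert restored == module_path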
@@ -204,7 +204,6 @@ class ToolkitModuleMixin:

         return lx * scale

-
 def lorm_forward(self: Network, x, *args, **kwargs):
     network: Network = self.network_ref()
     if not network.is_active:
@@ -492,6 +491,24 @@ class ToolkitNetworkMixin:
                 v = v.detach().clone().to("cpu").to(dtype)
                 save_dict[key] = v

+        if self.peft_format:
+            # lora_down = lora_A
+            # lora_up = lora_B
+            # no alpha
+
+            new_save_dict = {}
+            for key, value in save_dict.items():
+                if key.endswith('.alpha'):
+                    continue
+                new_key = key
+                new_key = new_key.replace('lora_down', 'lora_A')
+                new_key = new_key.replace('lora_up', 'lora_B')
+                # replace all $$ with .
+                new_key = new_key.replace('$$', '.')
+                new_save_dict[new_key] = value
+
+            save_dict = new_save_dict
+
         if metadata is None:
             metadata = OrderedDict()
         metadata = add_model_hash_to_meta(state_dict, metadata)
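
The save-time conversion above can be read as a standalone helper: drop alpha entries, rename lora_down/lora_up to PEFT's lora_A/lora_B, and restore the escaped dots. A minimal sketch of the same logic:

    def to_peft_state_dict(save_dict):
        # kohya-style keys -> peft/diffusers keys, as in the hunk above
        out = {}
        for key, value in save_dict.items():
            if key.endswith('.alpha'):
                continue  # peft stores no alpha
            key = key.replace('lora_down', 'lora_A')
            key = key.replace('lora_up', 'lora_B')
            key = key.replace('$$', '.')  # restore the escaped dots
            out[key] = value
        return out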
@@ -519,6 +536,20 @@ class ToolkitNetworkMixin:
             # replace old double __ with single _
             if self.is_pixart:
                 load_key = load_key.replace('__', '_')
+
+            if self.peft_format:
+                # lora_down = lora_A
+                # lora_up = lora_B
+                # no alpha
+                if load_key.endswith('.alpha'):
+                    continue
+                load_key = load_key.replace('lora_A', 'lora_down')
+                load_key = load_key.replace('lora_B', 'lora_up')
+                # replace all . with $$
+                load_key = load_key.replace('.', '$$')
+                load_key = load_key.replace('$$lora_down$$', '.lora_down.')
+                load_key = load_key.replace('$$lora_up$$', '.lora_up.')

             load_sd[load_key] = value

         # extract extra items from state dict
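
Loading performs the inverse mapping, and the order matters: every dot is first escaped to "$$", then only the dots around the lora_down/lora_up markers are put back. A sketch with a hypothetical key:

    def from_peft_key(load_key):
        load_key = load_key.replace('lora_A', 'lora_down')
        load_key = load_key.replace('lora_B', 'lora_up')
        load_key = load_key.replace('.', '$$')
        load_key = load_key.replace('$$lora_down$$', '.lora_down.')
        load_key = load_key.replace('$$lora_up$$', '.lora_up.')
        return load_key

    assert from_peft_key("transformer.blocks.0.attn.to_q.lora_A.weight") \
        == "transformer$$blocks$$0$$attn$$to_q.lora_down.weight"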
@@ -533,7 +564,8 @@ class ToolkitNetworkMixin:
             del load_sd[key]

         print(f"Missing keys: {to_delete}")
-        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
+        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
+                len(to_delete) == 1 and 'emb_params' in to_delete):
             print(" Attempting to load with forced keymap")
             return self.load_weights(file, force_weight_mapping=True)
@@ -657,4 +689,3 @@ class ToolkitNetworkMixin:
             params_reduced += (num_orig_module_params - num_lorem_params)

         return params_reduced
-
@@ -616,7 +616,17 @@ class StableDiffusion:
         if self.model_config.lora_path is not None:
             pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
-            self.unet.fuse_lora()
+            pipe.fuse_lora()
+            # unfortunately, there is not an easier way with peft
+            pipe.unload_lora_weights()

+        if self.model_config.assistant_lora_path is not None:
+            if self.model_config.lora_path is not None:
+                raise ValueError("Cannot have both lora and assistant lora")
+            print("Loading assistant lora")
+            pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
+            pipe.fuse_lora(lora_scale=1.0)
+            # unfortunately, there is not an easier way with peft
+            pipe.unload_lora_weights()

         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
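
The load-fuse-unload dance exists because diffusers' PEFT backend keeps adapters as separate modules; fusing bakes the weights into the base model, after which the adapter object itself can be discarded. A minimal standalone sketch (the pipeline class choice and the checkpoint path are illustrative, not from this commit):

    from diffusers import FluxPipeline  # assumed pipeline class

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
    pipe.load_lora_weights("path/to/assistant_lora.safetensors",
                           adapter_name="assistant_lora")
    pipe.fuse_lora(lora_scale=1.0)  # bake the adapter into the base weights
    pipe.unload_lora_weights()      # drop the peft layers; the fusion remains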
@@ -690,7 +700,15 @@ class StableDiffusion:
             pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
     ):
         merge_multiplier = 1.0
         # sample_folder = os.path.join(self.save_root, 'samples')

+        # if using assistant, unfuse it
+        if self.model_config.assistant_lora_path is not None:
+            print("Unloading assistant lora")
+            # unfortunately, there is not an easier way with peft
+            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
+            self.pipeline.fuse_lora(lora_scale=-1.0)
+            self.pipeline.unload_lora_weights()
+
         if self.network is not None:
             self.network.eval()
             network = self.network
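
Fusing the same adapter again at lora_scale=-1.0 acts as an unfuse: fuse_lora adds scale * (up @ down) to each base weight, so -1.0 subtracts exactly what +1.0 added, up to floating-point error. A toy demonstration of the arithmetic:

    import torch

    W = torch.randn(8, 8)                  # base weight
    down, up = torch.randn(4, 8), torch.randn(8, 4)
    fused = W + 1.0 * (up @ down)          # fuse at +1.0 for sampling
    unfused = fused + -1.0 * (up @ down)   # "fuse" at -1.0 to restore
    assert torch.allclose(unfused, W, atol=1e-5)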
@@ -1162,6 +1180,14 @@ class StableDiffusion:
             network.merge_out(merge_multiplier)
         # self.tokenizer.to(original_device_dict['tokenizer'])

+        # re-fuse loras
+        if self.model_config.assistant_lora_path is not None:
+            print("Loading assistant lora")
+            # unfortunately, there is not an easier way with peft
+            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
+            self.pipeline.fuse_lora(lora_scale=1.0)
+            self.pipeline.unload_lora_weights()
+
     def get_latent_noise(
             self,
             height=None,