Add a method to have an inference only lora

This commit is contained in:
Jaret Burkett
2024-10-04 10:06:53 -06:00
parent 28e6f00790
commit a800c9d19e
2 changed files with 49 additions and 16 deletions

View File

@@ -410,6 +410,7 @@ class ModelConfig:
self.lora_path = kwargs.get('lora_path', None)
# mainly for decompression loras for distilled models
self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
self.inference_lora_path = kwargs.get('inference_lora_path', None)
self.latent_space_version = kwargs.get('latent_space_version', None)
# only for SDXL models for now

View File

@@ -497,33 +497,46 @@ class StableDiffusion:
transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None:
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None:
raise ValueError("Cannot load both assistant lora and inference lora at the same time")
if self.model_config.lora_path:
raise ValueError("Cannot load both assistant lora and lora at the same time")
if not self.is_flux:
raise ValueError("Assistant lora is only supported for flux models currently")
raise ValueError("Assistant/ inference lora is only supported for flux models currently")
load_lora_path = self.model_config.inference_lora_path
if load_lora_path is None:
load_lora_path = self.model_config.assistant_lora_path
if os.path.isdir(self.model_config.assistant_lora_path):
self.model_config.assistant_lora_path = os.path.join(
self.model_config.assistant_lora_path, "pytorch_lora_weights.safetensors"
if os.path.isdir(load_lora_path):
load_lora_path = os.path.join(
load_lora_path, "pytorch_lora_weights.safetensors"
)
elif not os.path.exists(self.model_config.assistant_lora_path):
print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
elif not os.path.exists(load_lora_path):
print(f"Grabbing lora from the hub: {load_lora_path}")
new_lora_path = hf_hub_download(
self.model_config.assistant_lora_path,
load_lora_path,
filename="pytorch_lora_weights.safetensors"
)
# replace the path
self.model_config.assistant_lora_path = new_lora_path
load_lora_path = new_lora_path
if self.model_config.inference_lora_path is not None:
self.model_config.inference_lora_path = new_lora_path
if self.model_config.assistant_lora_path is not None:
self.model_config.assistant_lora_path = new_lora_path
# for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
# quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
# it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
# so we will merge in now and sample with -1 weight later
self.invert_assistant_lora = True
# trigger it to get merged in
self.model_config.lora_path = self.model_config.assistant_lora_path
if self.model_config.assistant_lora_path is not None:
# for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
# quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
# it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
# so we will merge in now and sample with -1 weight later
self.invert_assistant_lora = True
# trigger it to get merged in
self.model_config.lora_path = self.model_config.assistant_lora_path
if self.model_config.lora_path is not None:
print("Fusing in LoRA")
@@ -763,6 +776,13 @@ class StableDiffusion:
# invert and disable during training
self.assistant_lora.multiplier = -1.0
self.assistant_lora.is_active = False
if self.model_config.inference_lora_path is not None:
print("Loading inference lora")
self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
self.model_config.inference_lora_path, self)
# disable during training
self.assistant_lora.is_active = False
if self.is_pixart and self.vae_scale_factor == 16:
# TODO make our own pipeline?
@@ -840,6 +860,12 @@ class StableDiffusion:
self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
else:
self.assistant_lora.is_active = False
if self.model_config.inference_lora_path is not None:
print("Loading inference lora")
self.assistant_lora.is_active = True
# move weights on to the device
self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
if self.network is not None:
self.network.eval()
@@ -1356,6 +1382,12 @@ class StableDiffusion:
self.assistant_lora.force_to('cpu', self.torch_dtype)
else:
self.assistant_lora.is_active = True
if self.model_config.inference_lora_path is not None:
print("Unloading inference lora")
self.assistant_lora.is_active = False
# move weights off the device
self.assistant_lora.force_to('cpu', self.torch_dtype)
flush()