Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-28 08:13:58 +00:00)
Add a method to have an inference only lora
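The new option is read in ModelConfig alongside the existing assistant_lora_path (first hunk below) and is only activated at sample time. A minimal usage sketch, not the project's documented API: the module path and the name_or_path value are assumptions; inference_lora_path is the kwarg this commit adds.

# Minimal sketch of passing the new option to ModelConfig.
from toolkit.config_modules import ModelConfig  # assumed module path

model_config = ModelConfig(
    name_or_path="black-forest-labs/FLUX.1-schnell",  # illustrative flux model
    # LoRA that is only switched on for sampling and kept inactive during training;
    # per the diff below it may be a local .safetensors file, a directory containing
    # pytorch_lora_weights.safetensors, or a Hugging Face Hub repo id
    inference_lora_path="path/to/pytorch_lora_weights.safetensors",
)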
@@ -410,6 +410,7 @@ class ModelConfig:
         self.lora_path = kwargs.get('lora_path', None)
         # mainly for decompression loras for distilled models
         self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
+        self.inference_lora_path = kwargs.get('inference_lora_path', None)
         self.latent_space_version = kwargs.get('latent_space_version', None)
 
         # only for SDXL models for now
@@ -497,33 +497,46 @@ class StableDiffusion:
             transformer.to(torch.device(self.quantize_device), dtype=dtype)
             flush()
 
-        if self.model_config.assistant_lora_path is not None:
+        if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
+            if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None:
+                raise ValueError("Cannot load both assistant lora and inference lora at the same time")
+
             if self.model_config.lora_path:
                 raise ValueError("Cannot load both assistant lora and lora at the same time")
 
             if not self.is_flux:
-                raise ValueError("Assistant lora is only supported for flux models currently")
+                raise ValueError("Assistant/ inference lora is only supported for flux models currently")
+
+            load_lora_path = self.model_config.inference_lora_path
+            if load_lora_path is None:
+                load_lora_path = self.model_config.assistant_lora_path
 
-            if os.path.isdir(self.model_config.assistant_lora_path):
-                self.model_config.assistant_lora_path = os.path.join(
-                    self.model_config.assistant_lora_path, "pytorch_lora_weights.safetensors"
+            if os.path.isdir(load_lora_path):
+                load_lora_path = os.path.join(
+                    load_lora_path, "pytorch_lora_weights.safetensors"
                 )
-            elif not os.path.exists(self.model_config.assistant_lora_path):
-                print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
+            elif not os.path.exists(load_lora_path):
+                print(f"Grabbing lora from the hub: {load_lora_path}")
                 new_lora_path = hf_hub_download(
-                    self.model_config.assistant_lora_path,
+                    load_lora_path,
                     filename="pytorch_lora_weights.safetensors"
                 )
                 # replace the path
-                self.model_config.assistant_lora_path = new_lora_path
+                load_lora_path = new_lora_path
+
+                if self.model_config.inference_lora_path is not None:
+                    self.model_config.inference_lora_path = new_lora_path
+                if self.model_config.assistant_lora_path is not None:
+                    self.model_config.assistant_lora_path = new_lora_path
 
-            # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
-            # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
-            # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
-            # so we will merge in now and sample with -1 weight later
-            self.invert_assistant_lora = True
-            # trigger it to get merged in
-            self.model_config.lora_path = self.model_config.assistant_lora_path
+            if self.model_config.assistant_lora_path is not None:
+                # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
+                # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps
+                # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half
+                # so we will merge in now and sample with -1 weight later
+                self.invert_assistant_lora = True
+                # trigger it to get merged in
+                self.model_config.lora_path = self.model_config.assistant_lora_path
 
         if self.model_config.lora_path is not None:
             print("Fusing in LoRA")
@@ -763,6 +776,13 @@ class StableDiffusion:
                 # invert and disable during training
                 self.assistant_lora.multiplier = -1.0
                 self.assistant_lora.is_active = False
 
+        if self.model_config.inference_lora_path is not None:
+            print("Loading inference lora")
+            self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+                self.model_config.inference_lora_path, self)
+            # disable during training
+            self.assistant_lora.is_active = False
+
         if self.is_pixart and self.vae_scale_factor == 16:
             # TODO make our own pipeline?
@@ -840,6 +860,12 @@ class StableDiffusion:
                 self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
             else:
                 self.assistant_lora.is_active = False
 
+        if self.model_config.inference_lora_path is not None:
+            print("Loading inference lora")
+            self.assistant_lora.is_active = True
+            # move weights on to the device
+            self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
+
         if self.network is not None:
             self.network.eval()
@@ -1356,6 +1382,12 @@ class StableDiffusion:
                 self.assistant_lora.force_to('cpu', self.torch_dtype)
             else:
                 self.assistant_lora.is_active = True
 
+        if self.model_config.inference_lora_path is not None:
+            print("Unloading inference lora")
+            self.assistant_lora.is_active = False
+            # move weights off the device
+            self.assistant_lora.force_to('cpu', self.torch_dtype)
+
         flush()
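Taken together, the later hunks give the inference lora a simpler lifecycle than the assistant lora: it is never merged into the base weights or inverted, only switched on and moved to the device for sampling, then switched off again for training. A rough sketch of that toggle using the attributes shown in the diff; the helper function names are hypothetical.

# Hypothetical helpers illustrating the toggle; 'sd' is a StableDiffusion instance.
def enable_inference_lora_for_sampling(sd):
    if sd.model_config.inference_lora_path is not None:
        sd.assistant_lora.is_active = True
        sd.assistant_lora.force_to(sd.device_torch, sd.torch_dtype)

def disable_inference_lora_for_training(sd):
    if sd.model_config.inference_lora_path is not None:
        sd.assistant_lora.is_active = False
        sd.assistant_lora.force_to('cpu', sd.torch_dtype)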