mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
from typing import TYPE_CHECKING
|
|
from toolkit.config_modules import NetworkConfig
|
|
from toolkit.lora_special import LoRASpecialNetwork
|
|
from safetensors.torch import load_file
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.stable_diffusion_model import StableDiffusion
|
|
|
|
|
|
def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
|
|
if not sd.is_flux:
|
|
raise ValueError("Only Flux models can load assistant adapters currently.")
|
|
pipe = sd.pipeline
|
|
print(f"Loading assistant adapter from {adapter_path}")
|
|
adapter_name = adapter_path.split("/")[-1].split(".")[0]
|
|
lora_state_dict = load_file(adapter_path)
|
|
|
|
linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0])
|
|
# linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item())
|
|
linear_alpha = linear_dim
|
|
transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict
|
|
# get dim and scale
|
|
network_config = NetworkConfig(
|
|
linear=linear_dim,
|
|
linear_alpha=linear_alpha,
|
|
transformer_only=transformer_only,
|
|
)
|
|
|
|
network = LoRASpecialNetwork(
|
|
text_encoder=pipe.text_encoder,
|
|
unet=pipe.transformer,
|
|
lora_dim=network_config.linear,
|
|
multiplier=1.0,
|
|
alpha=network_config.linear_alpha,
|
|
train_unet=True,
|
|
train_text_encoder=False,
|
|
is_flux=True,
|
|
network_config=network_config,
|
|
network_type=network_config.type,
|
|
transformer_only=network_config.transformer_only,
|
|
is_assistant_adapter=True
|
|
)
|
|
network.apply_to(
|
|
pipe.text_encoder,
|
|
pipe.transformer,
|
|
apply_text_encoder=False,
|
|
apply_unet=True
|
|
)
|
|
network.force_to(sd.device_torch, dtype=sd.torch_dtype)
|
|
network.eval()
|
|
network._update_torch_multiplier()
|
|
network.load_weights(lora_state_dict)
|
|
network.is_active = True
|
|
|
|
return network
|