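"""
Extract a LoRA from a fine-tuned Flex.1 (Flux-architecture) transformer by
decomposing the difference between the tuned weights and the base weights.

Part of ostris/ai-toolkit (https://github.com/ostris/ai-toolkit).

Example invocation (the script filename below is illustrative):

    python extract_flex_lora.py --tuned /path/to/tuned_model \
        --output /path/to/lora.safetensors --rank 32

By default only the transformer blocks are extracted and the base model is
ostris/Flex.1-alpha; pass --full to extract from the whole transformer.
"""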
import os
import argparse
from collections import OrderedDict

from tqdm import tqdm

parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for lora")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")
parser.add_argument("--full", action="store_true", help="Do a full transformer extraction, not just transformer blocks")

args = parser.parse_args()

if True:
    # select the GPU before torch is imported so CUDA only ever sees the
    # requested device; the imports stay inside this block so they run after
    # CUDA_VISIBLE_DEVICES has been set
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    import torch
    from safetensors.torch import load_file, save_file
    from lycoris.utils import extract_linear, extract_conv, make_sparse
    from diffusers import FluxTransformer2DModel

base = args.base
tuned = args.tuned
output_path = args.output
dim = args.rank

# create the output directory only if the output path actually contains one
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

state_dict_base = {}
state_dict_tuned = {}

output_dict = {}

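# extract_diff walks both transformers, diffs every matching Linear / Conv /
# Norm module, and returns a LoRA-style state dict. In "fixed" mode each
# difference is decomposed to the requested rank with the lycoris extract
# helpers; in "full" mode the raw weight / bias differences are stored instead.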
@torch.no_grad()
def extract_diff(
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    # small_conv=True,
    small_conv=False,
):
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
        "LoRACompatibleLinear",
        "LoRACompatibleConv",
    ]
    LORA_PREFIX_UNET = "transformer"

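    # make_state_dict pairs modules by name: `root_module` is the base model and
    # `target_module` is the tuned one. Inside the loop below, `weights` holds
    # the base module and `module` the tuned module, so every extracted
    # difference is (tuned - base).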
    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name
            # lora_name = lora_name.replace(".", "_")
            layer = module.__class__.__name__

            if 'transformer_blocks' not in lora_name and not args.full:
                continue

            # only layer types we know how to diff are handled; everything else is skipped
            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
                "LoRACompatibleLinear",
                "LoRACompatibleConv",
            }:
                root_weight = module.weight
                try:
                    # skip modules whose weights did not change during tuning
                    if torch.allclose(root_weight, weights.weight):
                        continue
                except Exception:
                    # allclose raises when the shapes differ; skip those modules too
                    continue
            else:
                continue
            module = module.to(extract_device, torch.float32)
            weights = weights.to(extract_device, torch.float32)

            if mode == "full":
                decompose_mode = "full"
            elif layer == "Linear":
                weight, decompose_mode = extract_linear(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
            elif layer == "Conv2d":
                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
                weight, decompose_mode = extract_conv(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param if is_linear else conv_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    diff = (
                        (
                            root_weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
                    del extract_c
            else:
                module = module.to("cpu")
                weights = weights.to("cpu")
                continue

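            # low-rank results are stored under PEFT-style lora_A / lora_B keys;
            # the optional sparse bias keeps part of the residual that the
            # low-rank factors do not capture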
if decompose_mode == "low rank":
|
|
loras[f"{lora_name}.lora_A.weight"] = (
|
|
extract_a.detach().cpu().contiguous().half()
|
|
)
|
|
loras[f"{lora_name}.lora_B.weight"] = (
|
|
extract_b.detach().cpu().contiguous().half()
|
|
)
|
|
# loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
|
|
if use_bias:
|
|
diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
|
|
sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
|
|
|
|
indices = sparse_diff.indices().to(torch.int16)
|
|
values = sparse_diff.values().half()
|
|
loras[f"{lora_name}.bias_indices"] = indices
|
|
loras[f"{lora_name}.bias_values"] = values
|
|
loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
|
|
torch.int16
|
|
)
|
|
del extract_a, extract_b, diff
|
|
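            # "full" results keep the raw differences: norm layers are saved
            # under w_norm / b_norm, all other layers under diff / diff_b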
elif decompose_mode == "full":
|
|
if "Norm" in layer:
|
|
w_key = "w_norm"
|
|
b_key = "b_norm"
|
|
else:
|
|
w_key = "diff"
|
|
b_key = "diff_b"
|
|
weight_diff = module.weight - weights.weight
|
|
loras[f"{lora_name}.{w_key}"] = (
|
|
weight_diff.detach().cpu().contiguous().half()
|
|
)
|
|
if getattr(weights, "bias", None) is not None:
|
|
bias_diff = module.bias - weights.bias
|
|
loras[f"{lora_name}.{b_key}"] = (
|
|
bias_diff.detach().cpu().contiguous().half()
|
|
)
|
|
else:
|
|
raise NotImplementedError
|
|
module = module.to("cpu", torch.bfloat16)
|
|
weights = weights.to("cpu", torch.bfloat16)
|
|
return loras
|
|
|
|
    all_loras = {}

    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )
    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # report how many modules produced LoRA weights
    all_lora_name = set()
    for k in all_loras:
        lora_name, weight = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(len(all_lora_name))
    return all_loras


# load the base and tuned transformers from their diffusers checkpoints
print("Loading Base")
base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)

print("Loading Tuned")
tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16)

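# run the extraction on the GPU selected via --gpu; the same rank is used for
# both the linear and the conv decompositions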
output_dict = extract_diff(
    base_model,
    tuned_model,
    mode="fixed",
    linear_mode_param=dim,
    conv_mode_param=dim,
    extract_device="cuda",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
)

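# safetensors metadata values must be strings; "format": "pt" marks the file as
# holding PyTorch tensors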
meta = OrderedDict()
meta['format'] = 'pt'

save_file(output_dict, output_path, metadata=meta)

print("Done")