Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added a config file for full finetuning Flex. Added a LoRA extraction script for Flex
config/examples/train_full_fine_tune_flex.yaml (new file, 103 lines)
@@ -0,0 +1,103 @@
---
# This configuration requires 48GB of VRAM or more to operate
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_flex_finetune_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      save:
        dtype: bf16 # precision to save
        save_every: 250 # save every this many steps
        max_step_saves_to_keep: 2 # how many intermittent saves to keep
        save_format: 'diffusers' # 'diffusers'
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape backslashes with another backslash, e.g.
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
          shuffle_tokens: false # shuffle caption order, split by commas
          # cache_latents_to_disk: true # leave this true unless you know what you're doing
          resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
      train:
        batch_size: 1
        # IMPORTANT! For Flex, you must bypass the guidance embedder during training
        bypass_guidance_embedding: true

        # can be 'sigmoid', 'linear', or 'lognorm_blend'
        timestep_type: 'sigmoid'

        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation: 1
        train_unet: true
        train_text_encoder: false # probably won't work with flex
        gradient_checkpointing: true # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch" # for training only
        optimizer: "adafactor"
        lr: 3e-5

        # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
        # 0.1 means 10% of parameters are active at each step. Only works with adafactor

        # do_paramiter_swapping: true
        # paramiter_swapping_factor: 0.9

        # uncomment this to skip the pre-training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # you will probably need this for flex if your gpu supports it; other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "ostris/Flex.1-alpha"
        is_flux: true # flex is flux architecture
        # full finetuning quantized models is a crapshoot and results in subpar outputs
        # quantize: true
        # you can quantize just the T5 text encoder here to save vram
        quantize_te: true
      sample:
        sampler: "flowmatch" # must match train.noise_scheduler
        sample_every: 250 # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, laser lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: "" # not used on flex
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
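Since the config above calls out `bypass_guidance_embedding: true` as required for Flex, a small pre-flight check can catch a missing flag before a long run. This is a minimal sketch, assuming PyYAML is installed; the path and key lookups simply mirror the example YAML as written, not any validation ai-toolkit itself performs:

import yaml

# path mirrors the example config added in this commit; adjust to your copy
config_path = "config/examples/train_full_fine_tune_flex.yaml"

with open(config_path) as f:
    cfg = yaml.safe_load(f)

process = cfg["config"]["process"][0]
train = process["train"]

# Flex requires the guidance embedder to be bypassed during training
assert train.get("bypass_guidance_embedding") is True, "bypass_guidance_embedding must be true for Flex"
print("model:", process["model"]["name_or_path"])
print("steps:", train["steps"], "| optimizer:", train["optimizer"], "| lr:", train["lr"])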
scripts/extract_lora_from_flex.py (new file, 244 lines)
@@ -0,0 +1,244 @@
import os
from tqdm import tqdm
import argparse
from collections import OrderedDict

parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for lora")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU to process extraction")

args = parser.parse_args()

if True:
    # set the cuda environment variable before torch initializes CUDA,
    # which is why these imports live inside this block
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    import torch
    from safetensors.torch import load_file, save_file
    from lycoris.utils import extract_linear, extract_conv, make_sparse
    from diffusers import FluxTransformer2DModel

base = args.base
tuned = args.tuned
output_path = args.output
dim = args.rank

# only create a parent directory if the output path actually has one
if os.path.dirname(output_path):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

state_dict_base = {}
state_dict_tuned = {}

output_dict = {}


@torch.no_grad()
def extract_diff(
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    # small_conv=True,
    small_conv=False,
):
    # decompose the tuned-minus-base weight difference into low-rank LoRA factors
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
        "LoRACompatibleLinear",
        "LoRACompatibleConv"
    ]
    LORA_PREFIX_UNET = "transformer"

    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name
            # lora_name = lora_name.replace(".", "_")
            layer = module.__class__.__name__
            if 'transformer_blocks' not in lora_name:
                continue

            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
                "LoRACompatibleLinear",
                "LoRACompatibleConv"
            }:
                root_weight = module.weight
                try:
                    # skip layers whose weights did not change during finetuning
                    if torch.allclose(root_weight, weights.weight):
                        continue
                except Exception:
                    continue
            else:
                continue
            module = module.to(extract_device, torch.float32)
            weights = weights.to(extract_device, torch.float32)

            if mode == "full":
                decompose_mode = "full"
            elif layer == "Linear":
                weight, decompose_mode = extract_linear(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
            elif layer == "Conv2d":
                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
                weight, decompose_mode = extract_conv(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param if is_linear else conv_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    diff = (
                        (
                            root_weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
                    del extract_c
            else:
                module = module.to("cpu")
                weights = weights.to("cpu")
                continue

            if decompose_mode == "low rank":
                loras[f"{lora_name}.lora_A.weight"] = (
                    extract_a.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.lora_B.weight"] = (
                    extract_b.detach().cpu().contiguous().half()
                )
                # loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
                if use_bias:
                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                    indices = sparse_diff.indices().to(torch.int16)
                    values = sparse_diff.values().half()
                    loras[f"{lora_name}.bias_indices"] = indices
                    loras[f"{lora_name}.bias_values"] = values
                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
                        torch.int16
                    )
                del extract_a, extract_b, diff
            elif decompose_mode == "full":
                if "Norm" in layer:
                    w_key = "w_norm"
                    b_key = "b_norm"
                else:
                    w_key = "diff"
                    b_key = "diff_b"
                weight_diff = module.weight - weights.weight
                loras[f"{lora_name}.{w_key}"] = (
                    weight_diff.detach().cpu().contiguous().half()
                )
                if getattr(weights, "bias", None) is not None:
                    bias_diff = module.bias - weights.bias
                    loras[f"{lora_name}.{b_key}"] = (
                        bias_diff.detach().cpu().contiguous().half()
                    )
            else:
                raise NotImplementedError
            module = module.to("cpu", torch.bfloat16)
            weights = weights.to("cpu", torch.bfloat16)
        return loras

    all_loras = {}

    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )
    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    all_lora_name = set()
    for k in all_loras:
        lora_name, weight = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(len(all_lora_name))
    return all_loras


# load the base and tuned transformers from their diffusers checkpoints
print("Loading Base")
base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)

print("Loading Tuned")
tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16)

output_dict = extract_diff(
    base_model,
    tuned_model,
    mode="fixed",
    linear_mode_param=dim,
    conv_mode_param=dim,
    extract_device="cuda",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
)

meta = OrderedDict()
meta['format'] = 'pt'

save_file(output_dict, output_path, metadata=meta)

print("Done")
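As a rough usage sketch, not part of the script itself: after a finetune saved in diffusers format, the extraction can be run and the resulting file inspected to confirm the expected lora_A/lora_B pairs. The paths and rank below are placeholders, and the invocation assumes the tuned checkpoint contains a `transformer` subfolder as the script expects.

# hypothetical invocation, shown here only as a comment (paths are placeholders):
# python scripts/extract_lora_from_flex.py --tuned output/my_first_flex_finetune_v1 \
#     --output output/my_first_flex_finetune_v1_lora.safetensors --rank 32
from safetensors.torch import load_file

lora_sd = load_file("output/my_first_flex_finetune_v1_lora.safetensors")

# each extracted module contributes a lora_A/lora_B pair
pairs = {k.rsplit(".lora_A.weight", 1)[0] for k in lora_sd if k.endswith(".lora_A.weight")}
print(f"{len(pairs)} low-rank modules extracted")

if pairs:
    # under the usual A/B convention, lora_A is (rank, in_features) and lora_B is (out_features, rank)
    any_key = next(iter(pairs))
    print("example:", any_key,
          tuple(lora_sd[f"{any_key}.lora_A.weight"].shape),
          tuple(lora_sd[f"{any_key}.lora_B.weight"].shape))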