diff --git a/config/examples/train_full_fine_tune_flex.yaml b/config/examples/train_full_fine_tune_flex.yaml
new file mode 100644
index 00000000..e449bb8d
--- /dev/null
+++ b/config/examples/train_full_fine_tune_flex.yaml
@@ -0,0 +1,103 @@
+---
+# This configuration requires 48GB of VRAM or more to operate
+job: extension
+config:
+  # this name will be used for the folder and filename
+  name: "my_first_flex_finetune_v1"
+  process:
+    - type: 'sd_trainer'
+      # root folder to save training sessions/samples/weights
+      training_folder: "output"
+      # uncomment to see performance stats in the terminal every N steps
+      # performance_log_every: 1000
+      device: cuda:0
+      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
+      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
+      # trigger_word: "p3r5on"
+      save:
+        dtype: bf16 # precision to save
+        save_every: 250 # save every this many steps
+        max_step_saves_to_keep: 2 # how many intermittent saves to keep
+        save_format: 'diffusers' # 'diffusers'
+      datasets:
+        # datasets are a folder of images. captions need to be txt files with the same name as the image
+        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
+        # images will automatically be resized and bucketed into the resolution specified
+        # on windows, escape back slashes with another backslash so
+        # "C:\\path\\to\\images\\folder"
+        - folder_path: "/path/to/images/folder"
+          caption_ext: "txt"
+          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
+          shuffle_tokens: false # shuffle caption order, split by commas
+          # cache_latents_to_disk: true # leave this true unless you know what you're doing
+          resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
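+        # more dataset folders can likely be added as further list entries with the same keys, e.g. (illustrative):
+        # - folder_path: "/path/to/more/images"
+        #   caption_ext: "txt"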
+      train:
+        batch_size: 1
+
+        # IMPORTANT! For Flex, you must bypass the guidance embedder during training
+        bypass_guidance_embedding: true
+
+        # can be 'sigmoid', 'linear', or 'lognorm_blend'
+        timestep_type: 'sigmoid'
+
+        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
+        gradient_accumulation: 1
+        train_unet: true
+        train_text_encoder: false # probably won't work with flex
+        gradient_checkpointing: true # need this on unless you have a ton of vram
+        noise_scheduler: "flowmatch" # for training only
+        optimizer: "adafactor"
+        lr: 3e-5
+
+        # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
+        # 0.1 is 10% of parameters active at each step. Only works with adafactor
+
+        # do_paramiter_swapping: true
+        # paramiter_swapping_factor: 0.9
+
+        # uncomment this to skip the pre training sample
+        # skip_first_sample: true
+        # uncomment to completely disable sampling
+        # disable_sampling: true
+
+        # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
+        ema_config:
+          use_ema: true
+          ema_decay: 0.99
+
+        # you will probably need this if your gpu supports it; for flex, other dtypes may not work correctly
+        dtype: bf16
+      model:
+        # huggingface model name or path
+        name_or_path: "ostris/Flex.1-alpha"
+        is_flux: true # flex is flux architecture
+        # full finetuning quantized models is a crapshoot and results in subpar outputs
+        # quantize: true
+        # you can quantize just the T5 text encoder here to save vram
+        quantize_te: true
+      sample:
+        sampler: "flowmatch" # must match train.noise_scheduler
+        sample_every: 250 # sample every this many steps
+        width: 1024
+        height: 1024
+        prompts:
+          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
+          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
+          - "woman with red hair, playing chess at the park, bomb going off in the background"
+          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
+          - "a horse is a DJ at a night club, fish eye lens, smoke machine, laser lights, holding a martini"
+          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
+          - "a bear building a log cabin in the snow covered mountains"
+          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
+          - "hipster man with a beard, building a chair, in a wood shop"
+          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
+          - "a man holding a sign that says, 'this is a sign'"
+          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
+        neg: "" # not used on flex
+        seed: 42
+        walk_seed: true
+        guidance_scale: 4
+        sample_steps: 25
+# you can add any additional meta info here. [name] is replaced with config name at top
+meta:
+  name: "[name]"
+  version: '1.0'
diff --git a/scripts/extract_lora_from_flex.py b/scripts/extract_lora_from_flex.py
new file mode 100644
index 00000000..908c84e3
--- /dev/null
+++ b/scripts/extract_lora_from_flex.py
@@ -0,0 +1,244 @@
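+# Extracts a LoRA from a fully fine tuned Flex transformer by diffing it against the base model.
+# The tuned path must point to a diffusers-format checkpoint containing a 'transformer' subfolder.
+# Example invocation (paths are illustrative, not part of this patch):
+#   python scripts/extract_lora_from_flex.py \
+#       --tuned /path/to/finetuned/diffusers/checkpoint \
+#       --output /path/to/extracted_lora.safetensors \
+#       --rank 32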
+import os
+from tqdm import tqdm
+import argparse
+from collections import OrderedDict
+
+parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
+parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
+parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
+parser.add_argument("--output", type=str, required=True, help="Output path for lora")
+parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
+parser.add_argument("--gpu", type=int, default=0, help="GPU to use for extraction")
+
+args = parser.parse_args()
+
+if True:
+    # set CUDA_VISIBLE_DEVICES before importing torch so the chosen GPU is used
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+    import torch
+    from safetensors.torch import load_file, save_file
+    from lycoris.utils import extract_linear, extract_conv, make_sparse
+    from diffusers import FluxTransformer2DModel
+
+base = args.base
+tuned = args.tuned
+output_path = args.output
+dim = args.rank
+
+os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+state_dict_base = {}
+state_dict_tuned = {}
+
+output_dict = {}
+
+
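+# extract_diff walks the transformer blocks of the tuned model, takes the weight delta
+# (tuned - base) for each Linear/Conv2d layer, and decomposes it into low-rank
+# lora_A / lora_B factors at the requested rank via lycoris' extract_linear / extract_conv.
+# Layers whose weights are unchanged from the base model are skipped.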
+@torch.no_grad()
+def extract_diff(
+    base_unet,
+    db_unet,
+    mode="fixed",
+    linear_mode_param=0,
+    conv_mode_param=0,
+    extract_device="cpu",
+    use_bias=False,
+    sparsity=0.98,
+    # small_conv=True,
+    small_conv=False,
+):
+    UNET_TARGET_REPLACE_MODULE = [
+        "Linear",
+        "Conv2d",
+        "LayerNorm",
+        "GroupNorm",
+        "GroupNorm32",
+        "LoRACompatibleLinear",
+        "LoRACompatibleConv"
+    ]
+    LORA_PREFIX_UNET = "transformer"
+
+    def make_state_dict(
+        prefix,
+        root_module: torch.nn.Module,
+        target_module: torch.nn.Module,
+        target_replace_modules,
+    ):
+        loras = {}
+        temp = {}
+
+        for name, module in root_module.named_modules():
+            if module.__class__.__name__ in target_replace_modules:
+                temp[name] = module
+
+        for name, module in tqdm(
+            list((n, m) for n, m in target_module.named_modules() if n in temp)
+        ):
+            weights = temp[name]
+            lora_name = prefix + "." + name
+            # lora_name = lora_name.replace(".", "_")
+            layer = module.__class__.__name__
+            if 'transformer_blocks' not in lora_name:
+                continue
+
+            if layer in {
+                "Linear",
+                "Conv2d",
+                "LayerNorm",
+                "GroupNorm",
+                "GroupNorm32",
+                "Embedding",
+                "LoRACompatibleLinear",
+                "LoRACompatibleConv"
+            }:
+                root_weight = module.weight
+                try:
+                    if torch.allclose(root_weight, weights.weight):
+                        continue
+                except Exception:
+                    continue
+            else:
+                continue
+            module = module.to(extract_device, torch.float32)
+            weights = weights.to(extract_device, torch.float32)
+
+            if mode == "full":
+                decompose_mode = "full"
+            elif layer == "Linear":
+                weight, decompose_mode = extract_linear(
+                    (root_weight - weights.weight),
+                    mode,
+                    linear_mode_param,
+                    device=extract_device,
+                )
+                if decompose_mode == "low rank":
+                    extract_a, extract_b, diff = weight
+            elif layer == "Conv2d":
+                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
+                weight, decompose_mode = extract_conv(
+                    (root_weight - weights.weight),
+                    mode,
+                    linear_mode_param if is_linear else conv_mode_param,
+                    device=extract_device,
+                )
+                if decompose_mode == "low rank":
+                    extract_a, extract_b, diff = weight
+                if small_conv and not is_linear and decompose_mode == "low rank":
+                    dim = extract_a.size(0)
+                    (extract_c, extract_a, _), _ = extract_conv(
+                        extract_a.transpose(0, 1),
+                        "fixed",
+                        dim,
+                        extract_device,
+                        True,
+                    )
+                    extract_a = extract_a.transpose(0, 1)
+                    extract_c = extract_c.transpose(0, 1)
+                    loras[f"{lora_name}.lora_mid.weight"] = (
+                        extract_c.detach().cpu().contiguous().half()
+                    )
+                    diff = (
+                        (
+                            root_weight
+                            - torch.einsum(
+                                "i j k l, j r, p i -> p r k l",
+                                extract_c,
+                                extract_a.flatten(1, -1),
+                                extract_b.flatten(1, -1),
+                            )
+                        )
+                        .detach()
+                        .cpu()
+                        .contiguous()
+                    )
+                    del extract_c
+            else:
+                module = module.to("cpu")
+                weights = weights.to("cpu")
+                continue
+
+            if decompose_mode == "low rank":
+                loras[f"{lora_name}.lora_A.weight"] = (
+                    extract_a.detach().cpu().contiguous().half()
+                )
+                loras[f"{lora_name}.lora_B.weight"] = (
+                    extract_b.detach().cpu().contiguous().half()
+                )
+                # loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half()
+                if use_bias:
+                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
+                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
+
+                    indices = sparse_diff.indices().to(torch.int16)
+                    values = sparse_diff.values().half()
+                    loras[f"{lora_name}.bias_indices"] = indices
+                    loras[f"{lora_name}.bias_values"] = values
+                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
+                        torch.int16
+                    )
+                del extract_a, extract_b, diff
+            elif decompose_mode == "full":
+                if "Norm" in layer:
+                    w_key = "w_norm"
+                    b_key = "b_norm"
+                else:
+                    w_key = "diff"
+                    b_key = "diff_b"
+                weight_diff = module.weight - weights.weight
+                loras[f"{lora_name}.{w_key}"] = (
+                    weight_diff.detach().cpu().contiguous().half()
+                )
+                if getattr(weights, "bias", None) is not None:
+                    bias_diff = module.bias - weights.bias
+                    loras[f"{lora_name}.{b_key}"] = (
+                        bias_diff.detach().cpu().contiguous().half()
+                    )
+            else:
+                raise NotImplementedError
+            module = module.to("cpu", torch.bfloat16)
+            weights = weights.to("cpu", torch.bfloat16)
+        return loras
+
+    all_loras = {}
+
+    all_loras |= make_state_dict(
+        LORA_PREFIX_UNET,
+        base_unet,
+        db_unet,
+        UNET_TARGET_REPLACE_MODULE,
+    )
+    del base_unet, db_unet
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    all_lora_name = set()
+    for k in all_loras:
+        lora_name, weight = k.rsplit(".", 1)
+        all_lora_name.add(lora_name)
+    print(len(all_lora_name))
+    return all_loras
+
+
+# load the base and tuned transformers
+print("Loading Base")
+base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)
+
+print("Loading Tuned")
+tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16)
+
+output_dict = extract_diff(
+    base_model,
+    tuned_model,
+    mode="fixed",
+    linear_mode_param=dim,
+    conv_mode_param=dim,
+    extract_device="cuda",
+    use_bias=False,
+    sparsity=0.98,
+    small_conv=False,
+)
+
+meta = OrderedDict()
+meta['format'] = 'pt'
+
+save_file(output_dict, output_path, metadata=meta)
+
+print("Done")