From 81899310f877e6788f94b96facab1e775fc5c373 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 17 Aug 2024 06:58:39 -0600 Subject: [PATCH] Added support for training on flux schnell. Added example config and instructions for training on flux schnell --- README.md | 32 ++++++- config/examples/train_lora_flux_24gb.yaml | 2 +- .../train_lora_flux_schnell_24gb.yaml | 94 +++++++++++++++++++ toolkit/stable_diffusion_model.py | 25 ++++- 4 files changed, 144 insertions(+), 9 deletions(-) create mode 100644 config/examples/train_lora_flux_schnell_24gb.yaml diff --git a/README.md b/README.md index bcab1f2b..45779dff 100644 --- a/README.md +++ b/README.md @@ -64,9 +64,9 @@ but there are some reports of a bug when running on windows natively. I have only tested on linux for now. This is still extremely experimental and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all. -### Model License +### FLUX.1-dev -Training currently only works with FLUX.1-dev. Which means anything you train will inherit the +FLUX.1-dev has a non-commercial license. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. Here are the required steps to setup a license. @@ -74,10 +74,34 @@ Otherwise, this will fail. Here are the required steps to setup a license. 2. Make a file named `.env` in the root on this folder 3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here` +### FLUX.1-schnell + +FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want and it does not require a HF_TOKEN to train. +However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter). +It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended. + +To use it, You just need to add the assistant to the `model` section of your config file like so: + +```yaml + model: + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" + is_flux: true + quantize: true +``` + +You also need to adjust your sample steps since schnell does not require as many + +```yaml + sample: + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +``` + ### Training -1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` to the `config` folder and rename it to `whatever_you_want.yml` +1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml` 2. Edit the file following the comments in the file -3. Run the file like so `python3 run.py config/whatever_you_want.yml` +3. Run the file like so `python run.py config/whatever_you_want.yml` A folder with the name and the training folder from the config file will be created when you start. It will have all checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up diff --git a/config/examples/train_lora_flux_24gb.yaml b/config/examples/train_lora_flux_24gb.yaml index fe0d4c87..c3fd119a 100644 --- a/config/examples/train_lora_flux_24gb.yaml +++ b/config/examples/train_lora_flux_24gb.yaml @@ -48,7 +48,7 @@ config: # uncomment to completely disable sampling # disable_sampling: true # uncomment to use new vell curved weighting. Experimental but may produce better results - linear_timesteps: true +# linear_timesteps: true # ema will smooth out learning, but could slow it down. Recommended to leave on. ema_config: diff --git a/config/examples/train_lora_flux_schnell_24gb.yaml b/config/examples/train_lora_flux_schnell_24gb.yaml new file mode 100644 index 00000000..c6ef95e4 --- /dev/null +++ b/config/examples/train_lora_flux_schnell_24gb.yaml @@ -0,0 +1,94 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + # uncomment to use new bell curved weighting. Experimental but may produce better results +# linear_timesteps: true + + # ema will smooth out learning, but could slow it down. Recommended to leave on. + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "black-forest-labs/FLUX.1-schnell" + assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training + is_flux: true + quantize: true # run 8bit mixed precision + # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 1 # schnell does not do guidance + sample_steps: 4 # 1 - 4 works well +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 47ac104b..852dcebc 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -56,7 +56,7 @@ from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT -from toolkit.util.inverse_cfg import inverse_classifier_guidance +from huggingface_hub import hf_hub_download from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 from typing import TYPE_CHECKING @@ -496,10 +496,23 @@ class StableDiffusion: transformer.to(torch.device(self.quantize_device), dtype=dtype) flush() - if self.model_config.assistant_lora_path is not None and self.model_config.lora_path: - raise ValueError("Cannot load both assistant lora and lora at the same time") + if self.model_config.assistant_lora_path is not None: + if self.model_config.lora_path: + raise ValueError("Cannot load both assistant lora and lora at the same time") + + if not self.is_flux: + raise ValueError("Assistant lora is only supported for flux models currently") + + # handle downloading from the hub if needed + if not os.path.exists(self.model_config.assistant_lora_path): + print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}") + new_lora_path = hf_hub_download( + self.model_config.assistant_lora_path, + filename="pytorch_lora_weights.safetensors" + ) + # replace the path + self.model_config.assistant_lora_path = new_lora_path - if self.model_config.assistant_lora_path is not None and self.is_flux: # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half @@ -509,6 +522,10 @@ class StableDiffusion: self.model_config.lora_path = self.model_config.assistant_lora_path if self.model_config.lora_path is not None: + print("Fusing in LoRA") + # if doing low vram, do this on the gpu, painfully slow otherwise + if self.low_vram: + print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.") # need the pipe to do this unfortunately for now # we have to fuse in the weights before quantizing pipe: FluxPipeline = FluxPipeline(