Added support for training on flux schnell, along with an example config and instructions.

Jaret Burkett
2024-08-17 06:58:39 -06:00
parent f9179540d2
commit 81899310f8
4 changed files with 144 additions and 9 deletions

View File

@@ -64,9 +64,9 @@ but there are some reports of a bug when running on windows natively.
I have only tested on linux for now. This is still extremely experimental
and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.
### Model License
### FLUX.1-dev
Training currently only works with FLUX.1-dev, which means anything you train will inherit the
FLUX.1-dev has a non-commercial license, which means anything you train will inherit the
non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
Otherwise, this will fail. Here are the required steps to set up a license.
@@ -74,10 +74,34 @@ Otherwise, this will fail. Here are the required steps to setup a license.
2. Make a file named `.env` in the root of this folder
3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here`
### FLUX.1-schnell
FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want, and it does not require an HF_TOKEN to train.
However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.
To use it, you just need to add the assistant adapter to the `model` section of your config file like so:
```yaml
model:
  name_or_path: "black-forest-labs/FLUX.1-schnell"
  assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
  is_flux: true
  quantize: true
```
You also need to adjust your sample steps since schnell does not require as many:
```yaml
sample:
  guidance_scale: 1 # schnell does not do guidance
  sample_steps: 4 # 1 - 4 works well
```
### Training
1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` to the `config` folder and rename it to `whatever_you_want.yml`
1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
2. Edit the file following the comments in the file
3. Run the file like so `python3 run.py config/whatever_you_want.yml`
3. Run the file like so `python run.py config/whatever_you_want.yml`
A folder named after the config, inside the training folder set in the config file, will be created when you start. It will have all
checkpoints and images in it. You can stop the training at any time using ctrl+c, and when you resume, it will pick back up.
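
Before launching step 3, it can help to confirm the copied config parses and contains the schnell-specific keys. A minimal pre-flight sketch, assuming PyYAML is installed and the file was saved as `config/whatever_you_want.yml` (the key layout follows the example schnell config added in this commit):

```python
# Sketch only: sanity-check a copied FLUX.1-schnell config before launching run.py.
import yaml

with open("config/whatever_you_want.yml") as f:
    cfg = yaml.safe_load(f)

model_cfg = cfg["config"]["process"][0]["model"]
print("base model:", model_cfg["name_or_path"])
# schnell training requires the assistant adapter to be set
assert model_cfg.get("assistant_lora_path"), "assistant_lora_path is required for FLUX.1-schnell"
```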

View File

@@ -48,7 +48,7 @@ config:
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
        linear_timesteps: true
        # linear_timesteps: true
        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:

View File

@@ -0,0 +1,94 @@
---
job: extension
config:
  # this name will be the folder and filename name
  name: "my_first_flux_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16 # precision to save
        save_every: 250 # save every this many steps
        max_step_saves_to_keep: 4 # how many intermittent saves to keep
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
          shuffle_tokens: false # shuffle caption order, split by commas
          cache_latents_to_disk: true # leave this true unless you know what you're doing
          resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
      train:
        batch_size: 1
        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false # probably won't work with flux
        gradient_checkpointing: true # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch" # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
        # linear_timesteps: true
        # ema will smooth out learning, but could slow it down. Recommended to leave on.
        ema_config:
          use_ema: true
          ema_decay: 0.99
        # will probably need this if gpu supports it for flux, other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "black-forest-labs/FLUX.1-schnell"
        assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
        is_flux: true
        quantize: true # run 8bit mixed precision
        # low_vram is painfully slow to fuse in the adapter, avoid it unless absolutely necessary
        # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
      sample:
        sampler: "flowmatch" # must match train.noise_scheduler
        sample_every: 250 # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: "" # not used on flux
        seed: 42
        walk_seed: true
        guidance_scale: 1 # schnell does not do guidance
        sample_steps: 4 # 1 - 4 works well
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
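
The `datasets` section above expects every image to have a caption `.txt` with the same base name. A small check along those lines, sketched with the same placeholder folder path used in the config (only jpg, jpeg, and png are considered, matching the comment above):

```python
# Sketch only: list images in the training folder that have no matching caption file.
from pathlib import Path

folder = Path("/path/to/images/folder")  # placeholder path from the config above
for img in sorted(folder.iterdir()):
    if img.suffix.lower() in {".jpg", ".jpeg", ".png"} and not img.with_suffix(".txt").exists():
        print(f"missing caption: {img.name}")
```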

View File

@@ -56,7 +56,7 @@ from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from huggingface_hub import hf_hub_download
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from typing import TYPE_CHECKING
@@ -496,10 +496,23 @@ class StableDiffusion:
        transformer.to(torch.device(self.quantize_device), dtype=dtype)
        flush()
        if self.model_config.assistant_lora_path is not None and self.model_config.lora_path:
            raise ValueError("Cannot load both assistant lora and lora at the same time")
        if self.model_config.assistant_lora_path is not None:
            if self.model_config.lora_path:
                raise ValueError("Cannot load both assistant lora and lora at the same time")
            if not self.is_flux:
                raise ValueError("Assistant lora is only supported for flux models currently")
            # handle downloading from the hub if needed
            if not os.path.exists(self.model_config.assistant_lora_path):
                print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
                new_lora_path = hf_hub_download(
                    self.model_config.assistant_lora_path,
                    filename="pytorch_lora_weights.safetensors"
                )
                # replace the path
                self.model_config.assistant_lora_path = new_lora_path
        if self.model_config.assistant_lora_path is not None and self.is_flux:
            # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on
            # quantized weights, so it would have to run unmerged (slow). Since schnell samples in just 4 steps,
            # it is better to merge it in now and sample slowly later; otherwise training is slowed to half speed
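
For reference, the hub fallback added above reduces to this standalone call. It assumes, as the code does, that the adapter repo ships its weights as `pytorch_lora_weights.safetensors`:

```python
# Sketch of the same download fallback on its own: resolve a hub repo id to a local file.
import os
from huggingface_hub import hf_hub_download

adapter = "ostris/FLUX.1-schnell-training-adapter"
if not os.path.exists(adapter):  # not a local path, so treat it as a hub repo id
    adapter = hf_hub_download(adapter, filename="pytorch_lora_weights.safetensors")
print(adapter)  # local path to the cached LoRA weights
```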
@@ -509,6 +522,10 @@ class StableDiffusion:
            self.model_config.lora_path = self.model_config.assistant_lora_path
        if self.model_config.lora_path is not None:
            print("Fusing in LoRA")
            # if doing low vram, do this on the gpu, painfully slow otherwise
            if self.low_vram:
                print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
            # need the pipe to do this unfortunately for now
            # we have to fuse in the weights before quantizing
            pipe: FluxPipeline = FluxPipeline(