Added support for training on primary gpu with low_vram flag. Updated example script to remove creepy horse sample at that seed

This commit is contained in:
Jaret Burkett
2024-08-11 09:54:30 -06:00
parent fa02e774b0
commit ec1ea7aa0e
4 changed files with 30 additions and 13 deletions

View File

@@ -5,6 +5,10 @@ This is my research repo. I do a lot of experiments in it and it is possible tha
If something breaks, check out an earlier commit. This repo can train a lot of things, and it is
hard to keep up with all of them.
## Support my work
My work would not be possible without the amazing support of [Glif](https://glif.app/).
## Installation
Requirements:
@@ -43,16 +47,21 @@ pip install -r requirements.txt
### WIP. I am updating docs and optimizing as fast as I can. If there are bugs, open a ticket. Not knowing how to get it to work is NOT a bug. Be patient as I continue to develop it.
Training currently only works with FLUX.1-dev. Which means anything you train will inherit the
non commercial license. It is also a gated model, so you need to accept the license on HF before using it.
Otherwise, this will fail. Here are the required steps to setup a license.
### Requirements
You currently need a dedicated GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
your monitors, it will probably not fit as that takes up some ram. I may be able to get this lower, but for now,
It won't work. It may not work on Windows, I have only tested on linux for now. This is still extremely experimental
You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If the same GPU also drives
your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
the model on the CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
but there are some reports of a bug when running on Windows natively.
So far I have only tested on Linux. This is still extremely experimental,
and a lot of quantization and other tricks were needed to get it to fit in 24GB at all.
### Model License
Training currently only works with FLUX.1-dev, which means anything you train will inherit the
non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
Otherwise, this will fail. Here are the required steps to set up access.
1. Sign in to HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
2. Make a file named `.env` in the root of this folder
3. [Get a READ token from Hugging Face](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so: `HF_TOKEN=your_key_here` (a minimal loading sketch follows below)
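For reference, here is a minimal sketch of how a token stored this way can be picked up at runtime. It assumes python-dotenv plus the `huggingface_hub` client; the toolkit's own loading code may differ.

```python
import os

from dotenv import load_dotenv      # pip install python-dotenv
from huggingface_hub import login   # pip install huggingface_hub

# Read HF_TOKEN=... from the .env file in the repo root into the process environment.
load_dotenv()

# Authenticate so the gated FLUX.1-dev weights can be downloaded.
login(token=os.environ["HF_TOKEN"])
```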

View File

@@ -25,16 +25,16 @@ config:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
- folder_path: "/mnt/Datasets/1920s_illustrations"
# - folder_path: "/path/to/images/folder"
- folder_path: "/path/to/images/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of the time (see sketch below)
shuffle_tokens: false # shuffle caption order, split by commas
cache_latents_to_disk: true # leave this true unless you know what you're doing
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
num_workers: 0
train:
batch_size: 1
steps: 4000 # total number of steps to train
steps: 4000 # total number of steps to train; 500 - 4000 is a good range
gradient_accumulation_steps: 1
train_unet: true
train_text_encoder: false # probably won't work with flux
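To make the caption options above concrete, here is a small illustrative sketch of what `caption_dropout_rate` and `shuffle_tokens` mean. This is not the toolkit's dataloader, just the idea:

```python
import random

def prepare_caption(caption: str, dropout_rate: float = 0.05, shuffle_tokens: bool = False) -> str:
    # caption_dropout_rate: with this probability the caption is dropped entirely,
    # so the model occasionally trains without conditioning text.
    if random.random() < dropout_rate:
        return ""
    if shuffle_tokens:
        # shuffle_tokens: split the caption on commas and shuffle the chunks.
        parts = [p.strip() for p in caption.split(",")]
        random.shuffle(parts)
        caption = ", ".join(parts)
    return caption

# e.g. image2.jpg pairs with image2.txt:
# print(prepare_caption(open("image2.txt").read().strip(), 0.05, True))
```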
@@ -43,6 +43,8 @@ config:
noise_scheduler: "flowmatch" # for training only
optimizer: "adamw8bit"
lr: 4e-4
# uncomment this to skip the pre-training sample
# skip_first_sample: true
# ema will smooth out learning, but could slow it down. Recommended to leave on.
ema_config:
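The `ema_config` block above keeps an exponential moving average of the trained weights, which is what smooths out learning. A minimal sketch of the idea (not the toolkit's implementation; the decay value is illustrative):

```python
import torch

class EMA:
    """Exponential moving average of model weights (the idea behind ema_config)."""

    def __init__(self, params, decay: float = 0.99):
        self.decay = decay
        # Keep a smoothed "shadow" copy of every parameter.
        self.shadow = [p.detach().clone() for p in params]

    @torch.no_grad()
    def update(self, params):
        # After each optimizer step, blend the new weights into the running average.
        for avg, p in zip(self.shadow, params):
            avg.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
```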
@@ -56,6 +58,7 @@ config:
name_or_path: "black-forest-labs/FLUX.1-dev"
is_flux: true
quantize: true # run 8bit mixed precision
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
@@ -66,7 +69,7 @@ config:
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
- "woman with red hair, playing chess at the park, bomb going off in the background"
- "a woman holding a coffee cup, in a beanie, sitting at a cafe"
- "a horse in a night club dancing, fish eye lens, smoke machine, lazer lights, holding a martini, large group"
- "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
- "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
- "a bear building a log cabin in the snow covered mountains"
- "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"

View File

@@ -411,6 +411,7 @@ class ModelConfig:
# only for flux for now
self.quantize = kwargs.get("quantize", False)
self.low_vram = kwargs.get("low_vram", False)
pass
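Because the constructor reads these flags straight out of `kwargs`, the YAML `model:` block maps onto it roughly like this. A sketch only: the import path and the exact call site are assumptions, not the toolkit's real wiring.

```python
# Assumed import path for illustration; the class lives in the toolkit's config modules.
from toolkit.config_modules import ModelConfig

model_kwargs = {
    "name_or_path": "black-forest-labs/FLUX.1-dev",
    "is_flux": True,
    "quantize": True,   # 8-bit mixed precision
    "low_vram": True,   # quantize on CPU so a monitor-attached GPU can still train
}

model_config = ModelConfig(**model_kwargs)
assert model_config.low_vram is True  # picked up by kwargs.get("low_vram", False)
```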

View File

@@ -56,7 +56,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjecti
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
from optimum.quanto import freeze, qfloat8, quantize, QTensor
from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -174,6 +174,7 @@ class StableDiffusion:
self.is_flow_matching = True
self.quantize_device = quantize_device if quantize_device is not None else self.device
self.low_vram = self.model_config.low_vram
def load_model(self):
if self.is_loaded:
@@ -472,7 +473,9 @@ class StableDiffusion:
# low_cpu_mem_usage=False,
# device_map=None
)
transformer.to(torch.device(self.quantize_device), dtype=dtype)
if not self.low_vram:
# for low_vram, we leave it on the CPU. Quantizes slower, but allows training on the primary GPU
transformer.to(torch.device(self.quantize_device), dtype=dtype)
flush()
if self.model_config.lora_path is not None:
@@ -493,8 +496,9 @@ class StableDiffusion:
pipe.unload_lora_weights()
if self.model_config.quantize:
quantization_type = qfloat8
print("Quantizing transformer")
quantize(transformer, weights=qfloat8)
quantize(transformer, weights=quantization_type)
freeze(transformer)
transformer.to(self.device_torch)
else:
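Taken together, the low_vram path boils down to: quantize the transformer while it is still on the CPU, then move the much smaller quantized model to the GPU. Below is a standalone sketch of that flow using the same `optimum.quanto` calls; the dtype and device names are placeholders, not the toolkit's exact values.

```python
import torch
from optimum.quanto import freeze, qfloat8, quantize

def load_quantized_transformer(transformer: torch.nn.Module, device: str = "cuda",
                               low_vram: bool = False) -> torch.nn.Module:
    if not low_vram:
        # Normal path: move to the GPU first so quantization runs on-device and is fast.
        transformer.to(torch.device(device), dtype=torch.bfloat16)
    # low_vram path: the weights stay on the CPU here, so quantization is slower but
    # never requires the full bf16 model to fit in VRAM.
    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    # Only the already-quantized (much smaller) model is moved to the training GPU.
    transformer.to(torch.device(device))
    return transformer
```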