Added support for FLUX.1-Kontext-dev

This commit is contained in:
Jaret Burkett
2025-06-26 15:24:37 -06:00
parent 8d9c47316a
commit 60ef2f1df7
12 changed files with 570 additions and 4 deletions

View File

@@ -0,0 +1,106 @@
---
job: extension
config:
# this name will be the folder and filename name
name: "my_first_flux_kontext_lora_v1"
process:
- type: 'sd_trainer'
# root folder to save training sessions/samples/weights
training_folder: "output"
# uncomment to see performance stats in the terminal every N steps
# performance_log_every: 1000
device: cuda:0
# if a trigger word is specified, it will be added to captions of training data if it does not already exist
# alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
# trigger_word: "p3r5on"
network:
type: "lora"
linear: 16
linear_alpha: 16
save:
dtype: float16 # precision to save
save_every: 250 # save every this many steps
max_step_saves_to_keep: 4 # how many intermittent saves to keep
push_to_hub: false #change this to True to push your trained model to Hugging Face.
# You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
# hf_repo_id: your-username/your-model-slug
# hf_private: true #whether the repo is private or public
datasets:
# datasets are a folder of images. captions need to be txt files with the same name as the image
# for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
# images will automatically be resized and bucketed into the resolution specified
# on windows, escape back slashes with another backslash so
# "C:\\path\\to\\images\\folder"
- folder_path: "/path/to/images/folder"
# control path is the input images for kontext for a paired dataset. These are the source images you want to change.
# You can comment this out and only use normal images if you don't have a paired dataset.
# Control images need to match the filenames on the folder path but in
# a different folder. These do not need captions.
control_path: "/path/to/control/folder"
caption_ext: "txt"
caption_dropout_rate: 0.05 # will drop out the caption 5% of time
shuffle_tokens: false # shuffle caption order, split by commas
cache_latents_to_disk: true # leave this true unless you know what you're doing
# Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram.
resolution: [ 512, 768 ] # flux enjoys multiple resolutions
# resolution: [ 512, 768, 1024 ]
train:
batch_size: 1
steps: 3000 # total number of steps to train 500 - 4000 is a good range
gradient_accumulation_steps: 1
train_unet: true
train_text_encoder: false # probably won't work with flux
gradient_checkpointing: true # need the on unless you have a ton of vram
noise_scheduler: "flowmatch" # for training only
optimizer: "adamw8bit"
lr: 1e-4
timestep_type: "weighted" # sigmoid, linear, or weighted.
# uncomment this to skip the pre training sample
# skip_first_sample: true
# uncomment to completely disable sampling
# disable_sampling: true
# ema will smooth out learning, but could slow it down.
# ema_config:
# use_ema: true
# ema_decay: 0.99
# will probably need this if gpu supports it for flux, other dtypes may not work correctly
dtype: bf16
model:
# huggingface model name or path. This model is gated.
# visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions
# and then you can use this model.
name_or_path: "black-forest-labs/FLUX.1-Kontext-dev"
arch: "flux_kontext"
quantize: true # run 8bit mixed precision
# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
sample:
sampler: "flowmatch" # must match train.noise_scheduler
sample_every: 250 # sample every this many steps
width: 1024
height: 1024
prompts:
# you can add [trigger] to the prompts here and it will be replaced with the trigger word
# the --ctrl_img path is the one loaded to apply the kontext editing to
# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
- "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
- "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
- "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
- "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
- "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
- "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
- "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
- "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
- "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
- "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
neg: "" # not used on flux
seed: 42
walk_seed: true
guidance_scale: 4
sample_steps: 20
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
name: "[name]"
version: '1.0'

View File

@@ -2,8 +2,9 @@ from .chroma import ChromaModel
from .hidream import HidreamModel
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
AI_TOOLKIT_MODELS = [
# put a list of models here
ChromaModel, HidreamModel, FLiteModel, OmniGen2Model
ChromaModel, HidreamModel, FLiteModel, OmniGen2Model, FluxKontextModel
]

View File

@@ -0,0 +1 @@
from .flux_kontext import FluxKontextModel

View File

@@ -0,0 +1,400 @@
import os
from typing import TYPE_CHECKING, List
import torch
import torchvision
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from diffusers import FluxTransformer2DModel, AutoencoderKL, FluxKontextPipeline
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.mask import generate_random_mask, random_dialate_mask
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from einops import rearrange, repeat
import random
import torch.nn.functional as F
if TYPE_CHECKING:
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": True
}
class FluxKontextModel(BaseModel):
arch = "flux_kontext"
def __init__(
self,
device,
model_config: ModelConfig,
dtype='bf16',
custom_pipeline=None,
noise_scheduler=None,
**kwargs
):
super().__init__(
device,
model_config,
dtype,
custom_pipeline,
noise_scheduler,
**kwargs
)
self.is_flow_matching = True
self.is_transformer = True
self.target_lora_modules = ['FluxTransformer2DModel']
# static method to get the noise scheduler
@staticmethod
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
def get_bucket_divisibility(self):
return 16
def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading Flux Kontext model")
# will be updated if we detect a existing checkpoint in training folder
model_path = self.model_config.name_or_path
# this is the original path put in the model directory
# it is here because for finetuning we only save the transformer usually
# so we need this for the VAE, te, etc
base_model_path = self.model_config.extras_name_or_path
transformer_path = model_path
transformer_subfolder = 'transformer'
if os.path.exists(transformer_path):
transformer_subfolder = None
transformer_path = os.path.join(transformer_path, 'transformer')
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, 'text_encoder')
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
self.print_and_status_update("Loading transformer")
transformer = FluxTransformer2DModel.from_pretrained(
transformer_path,
subfolder=transformer_subfolder,
torch_dtype=dtype,
revision="7610c9af"
)
transformer.to(self.quantize_device, dtype=dtype)
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = get_qtype(self.model_config.qtype)
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading T5")
tokenizer_2 = T5TokenizerFast.from_pretrained(
base_model_path, subfolder="tokenizer_2", torch_dtype=dtype
)
text_encoder_2 = T5EncoderModel.from_pretrained(
base_model_path, subfolder="text_encoder_2", torch_dtype=dtype
)
text_encoder_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing T5")
quantize(text_encoder_2, weights=get_qtype(
self.model_config.qtype))
freeze(text_encoder_2)
flush()
self.print_and_status_update("Loading CLIP")
text_encoder = CLIPTextModel.from_pretrained(
base_model_path, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype)
self.noise_scheduler = FluxKontextModel.get_train_scheduler()
self.print_and_status_update("Making pipe")
pipe: FluxKontextPipeline = FluxKontextPipeline(
scheduler=self.noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=None,
tokenizer_2=tokenizer_2,
vae=vae,
transformer=None,
)
# for quantization, it works best to do these after making the pipe
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# just to make sure everything is on the right device and dtype
text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
text_encoder[1].to(self.device_torch)
text_encoder[1].requires_grad_(False)
text_encoder[1].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
# save it to the model class
self.vae = vae
self.text_encoder = text_encoder # list of text encoders
self.tokenizer = tokenizer # list of tokenizers
self.model = pipe.transformer
self.pipeline = pipe
self.print_and_status_update("Model Loaded")
def get_generation_pipeline(self):
scheduler = FluxKontextModel.get_train_scheduler()
pipeline: FluxKontextPipeline = FluxKontextPipeline(
scheduler=scheduler,
text_encoder=unwrap_model(self.text_encoder[0]),
tokenizer=self.tokenizer[0],
text_encoder_2=unwrap_model(self.text_encoder[1]),
tokenizer_2=self.tokenizer[1],
vae=unwrap_model(self.vae),
transformer=unwrap_model(self.transformer)
)
pipeline = pipeline.to(self.device_torch)
return pipeline
def generate_single_image(
self,
pipeline: FluxKontextPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
if gen_config.ctrl_img is None:
raise ValueError(
"Control image is required for Flux Kontext model generation."
)
else:
control_img = Image.open(gen_config.ctrl_img)
control_img = control_img.convert("RGB")
img = pipeline(
image=control_img,
prompt_embeds=conditional_embeds.text_embeds,
pooled_prompt_embeds=conditional_embeds.pooled_embeds,
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
generator=generator,
**extra
).images[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
guidance_embedding_scale: float,
bypass_guidance_embedding: bool,
**kwargs
):
with torch.no_grad():
bs, c, h, w = latent_model_input.shape
# if we have a control on the channel dimension, put it on the batch for packing
has_control = False
if latent_model_input.shape[1] == 32:
# chunk it and stack it on batch dimension
# dont update batch size for img_its
lat, control = torch.chunk(latent_model_input, 2, dim=1)
latent_model_input = torch.cat([lat, control], dim=0)
has_control = True
latent_model_input_packed = rearrange(
latent_model_input,
"b c (h ph) (w pw) -> b (h w) (c ph pw)",
ph=2,
pw=2
)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c",
b=bs).to(self.device_torch)
# handle control image ids
if has_control:
ctrl_ids = img_ids.clone()
ctrl_ids[..., 0] = 1
img_ids = torch.cat([img_ids, ctrl_ids], dim=1)
txt_ids = torch.zeros(
bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
# # handle guidance
if self.unet_unwrapped.config.guidance_embeds:
if isinstance(guidance_embedding_scale, list):
guidance = torch.tensor(
guidance_embedding_scale, device=self.device_torch)
else:
guidance = torch.tensor(
[guidance_embedding_scale], device=self.device_torch)
guidance = guidance.expand(latent_model_input.shape[0])
else:
guidance = None
if bypass_guidance_embedding:
bypass_flux_guidance(self.unet)
cast_dtype = self.unet.dtype
# changes from orig implementation
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
img_ids = img_ids[0]
latent_size = latent_model_input_packed.shape[1]
# move the kontext channels. We have them on batch dimension to here, but need to put them on the latent dimension
if has_control:
latent, control = torch.chunk(latent_model_input_packed, 2, dim=0)
latent_model_input_packed = torch.cat(
[latent, control], dim=1
)
latent_size = latent.shape[1]
noise_pred = self.unet(
hidden_states=latent_model_input_packed.to(
self.device_torch, cast_dtype),
timestep=timestep / 1000,
encoder_hidden_states=text_embeddings.text_embeds.to(
self.device_torch, cast_dtype),
pooled_projections=text_embeddings.pooled_embeds.to(
self.device_torch, cast_dtype),
txt_ids=txt_ids,
img_ids=img_ids,
guidance=guidance,
return_dict=False,
**kwargs,
)[0]
# remove kontext image conditioning
noise_pred = noise_pred[:, :latent_size]
if isinstance(noise_pred, QTensor):
noise_pred = noise_pred.dequantize()
noise_pred = rearrange(
noise_pred,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=latent_model_input.shape[2] // 2,
w=latent_model_input.shape[3] // 2,
ph=2,
pw=2,
c=self.vae.config.latent_channels
)
if bypass_guidance_embedding:
restore_flux_guidance(self.unet)
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
if self.pipeline.text_encoder.device != self.device_torch:
self.pipeline.text_encoder.to(self.device_torch)
prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
self.tokenizer,
self.text_encoder,
prompt,
max_length=512,
)
pe = PromptEmbeds(
prompt_embeds
)
pe.pooled_embeds = pooled_prompt_embeds
return pe
def get_model_has_grad(self):
# return from a weight if it has grad
return self.model.proj_out.weight.requires_grad
def get_te_has_grad(self):
# return from a weight if it has grad
return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the unet
transformer: FluxTransformer2DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
return (noise - batch.latents).detach()
def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'):
with torch.no_grad():
control_tensor = batch.control_tensor
if control_tensor is not None:
# we are not packed here, so we just need to pass them so we can pack them later
control_tensor = control_tensor * 2 - 1
control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype)
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype)
latents = torch.cat((latents, control_latent), dim=1)
return latents.detach()

View File

@@ -1,7 +1,7 @@
torchao==0.10.0
safetensors
git+https://github.com/jaretburkett/easy_dwpose.git
git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76
transformers==4.52.4
lycoris-lora==1.8.3
flatten_json

View File

@@ -757,6 +757,8 @@ class DatasetConfig:
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc
if self.control_path == '':
self.control_path = None
# inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will
# be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored
self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None)

View File

@@ -101,10 +101,14 @@ export default function SimpleJob({
setJobConfig(value, 'config.process[0].model.arch');
// update controls for datasets
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
const controls = newArch?.controls ?? [];
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
const newDataset = objectCopy(dataset);
newDataset.controls = controls;
if (!hasControlPath) {
newDataset.control_path = null; // reset control path if not applicable
}
return newDataset;
});
setJobConfig(datasets, 'config.process[0].datasets');
@@ -412,6 +416,17 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
options={datasetOptions}
/>
{modelArch?.additionalSections?.includes('datasets.control_path') && (
<SelectInput
label="Control Dataset"
docKey="datasets.control_path"
value={dataset.control_path ?? ''}
onChange={value =>
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
}
options={[{ value: '', label: <>&nbsp;</> }, ...datasetOptions]}
/>
)}
<NumberInput
label="LoRA Weight"
value={dataset.network_weight}
@@ -604,6 +619,18 @@ export default function SimpleJob({
)}
</div>
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
{
modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div className='text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg'>
<p className='font-semibold mb-1'>
Control Images
</p>
To use control images on samples, add --ctrl_img to the prompts below.
<br />
Example: <code className='bg-yellow-900 p-1'>make this a cartoon --ctrl_img /path/to/image.png</code>
</div>
)
}
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
<div key={i} className="flex items-center space-x-2">
<div className="flex-1">

View File

@@ -2,6 +2,7 @@ import { JobConfig, DatasetConfig } from '@/types';
export const defaultDatasetConfig: DatasetConfig = {
folder_path: '/path/to/images/folder',
control_path: null,
mask_path: null,
mask_min_value: 0.1,
default_caption: '',

View File

@@ -1,5 +1,8 @@
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img'
export interface ModelArch {
name: string;
label: string;
@@ -7,11 +10,11 @@ export interface ModelArch {
isVideoModel?: boolean;
defaults?: { [key: string]: any };
disableSections?: DisableableSections[];
additionalSections?: AdditionalSections[];
}
const defaultNameOrPath = '';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
export const modelArchs: ModelArch[] = [
{
@@ -27,6 +30,21 @@ export const modelArchs: ModelArch[] = [
},
disableSections: ['network.conv'],
},
{
name: 'flux_kontext',
label: 'FLUX.1-Kontext-dev',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
},
{
name: 'flex1',
label: 'Flex.1',

View File

@@ -48,6 +48,15 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'datasets.control_path': {
title: 'Control Dataset',
description: (
<>
The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs.
These images are fed as control/input images during training.
</>
),
},
};
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {

View File

@@ -83,6 +83,7 @@ export interface DatasetConfig {
cache_latents_to_disk?: boolean;
resolution: number[];
controls: string[];
control_path: string | null;
}
export interface EMAConfig {

View File

@@ -1 +1 @@
VERSION = "0.3.2"
VERSION = "0.3.3"