From 60ef2f1df71ec6e67ef7764d3f2acc95dc5dc429 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 26 Jun 2025 15:24:37 -0600 Subject: [PATCH] Added support for FLUX.1-Kontext-dev --- .../train_lora_flux_kontext_24gb.yaml | 106 +++++ .../diffusion_models/__init__.py | 3 +- .../diffusion_models/flux_kontext/__init__.py | 1 + .../flux_kontext/flux_kontext.py | 400 ++++++++++++++++++ requirements.txt | 2 +- toolkit/config_modules.py | 2 + ui/src/app/jobs/new/SimpleJob.tsx | 27 ++ ui/src/app/jobs/new/jobConfig.ts | 1 + ui/src/app/jobs/new/options.ts | 20 +- ui/src/docs.tsx | 9 + ui/src/types.ts | 1 + version.py | 2 +- 12 files changed, 570 insertions(+), 4 deletions(-) create mode 100644 config/examples/train_lora_flux_kontext_24gb.yaml create mode 100644 extensions_built_in/diffusion_models/flux_kontext/__init__.py create mode 100644 extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py diff --git a/config/examples/train_lora_flux_kontext_24gb.yaml b/config/examples/train_lora_flux_kontext_24gb.yaml new file mode 100644 index 00000000..d570da84 --- /dev/null +++ b/config/examples/train_lora_flux_kontext_24gb.yaml @@ -0,0 +1,106 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_flux_kontext_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word +# trigger_word: "p3r5on" + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + push_to_hub: false #change this to True to push your trained model to Hugging Face. + # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in +# hf_repo_id: your-username/your-model-slug +# hf_private: true #whether the repo is private or public + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # control path is the input images for kontext for a paired dataset. These are the source images you want to change. + # You can comment this out and only use normal images if you don't have a paired dataset. + # Control images need to match the filenames on the folder path but in + # a different folder. These do not need captions. + control_path: "/path/to/control/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + shuffle_tokens: false # shuffle caption order, split by commas + cache_latents_to_disk: true # leave this true unless you know what you're doing + # Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram. + resolution: [ 512, 768 ] # flux enjoys multiple resolutions + # resolution: [ 512, 768, 1024 ] + train: + batch_size: 1 + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false # probably won't work with flux + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + timestep_type: "weighted" # sigmoid, linear, or weighted. + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + + # ema will smooth out learning, but could slow it down. + + # ema_config: + # use_ema: true + # ema_decay: 0.99 + + # will probably need this if gpu supports it for flux, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path. This model is gated. + # visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions + # and then you can use this model. + name_or_path: "black-forest-labs/FLUX.1-Kontext-dev" + arch: "flux_kontext" + quantize: true # run 8bit mixed precision +# low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower. + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # the --ctrl_img path is the one loaded to apply the kontext editing to +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg" + - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg" + - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg" + - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg" + - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg" + - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg" + - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg" + - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg" + - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg" + - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg" + neg: "" # not used on flux + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 3faa87cc..4e03acb4 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -2,8 +2,9 @@ from .chroma import ChromaModel from .hidream import HidreamModel from .f_light import FLiteModel from .omnigen2 import OmniGen2Model +from .flux_kontext import FluxKontextModel AI_TOOLKIT_MODELS = [ # put a list of models here - ChromaModel, HidreamModel, FLiteModel, OmniGen2Model + ChromaModel, HidreamModel, FLiteModel, OmniGen2Model, FluxKontextModel ] diff --git a/extensions_built_in/diffusion_models/flux_kontext/__init__.py b/extensions_built_in/diffusion_models/flux_kontext/__init__.py new file mode 100644 index 00000000..081c949a --- /dev/null +++ b/extensions_built_in/diffusion_models/flux_kontext/__init__.py @@ -0,0 +1 @@ +from .flux_kontext import FluxKontextModel diff --git a/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py new file mode 100644 index 00000000..84806017 --- /dev/null +++ b/extensions_built_in/diffusion_models/flux_kontext/flux_kontext.py @@ -0,0 +1,400 @@ +import os +from typing import TYPE_CHECKING, List + +import torch +import torchvision +import yaml +from toolkit import train_tools +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from diffusers import FluxTransformer2DModel, AutoencoderKL, FluxKontextPipeline +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.mask import generate_random_mask, random_dialate_mask +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from einops import rearrange, repeat +import random +import torch.nn.functional as F + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + + + +class FluxKontextModel(BaseModel): + arch = "flux_kontext" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['FluxTransformer2DModel'] + + # static method to get the noise scheduler + @staticmethod + def get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + return 16 + + def load_model(self): + dtype = self.torch_dtype + self.print_and_status_update("Loading Flux Kontext model") + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + # this is the original path put in the model directory + # it is here because for finetuning we only save the transformer usually + # so we need this for the VAE, te, etc + base_model_path = self.model_config.extras_name_or_path + + transformer_path = model_path + transformer_subfolder = 'transformer' + if os.path.exists(transformer_path): + transformer_subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading transformer") + transformer = FluxTransformer2DModel.from_pretrained( + transformer_path, + subfolder=transformer_subfolder, + torch_dtype=dtype, + revision="7610c9af" + ) + transformer.to(self.quantize_device, dtype=dtype) + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + base_model_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + base_model_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + self.print_and_status_update("Loading CLIP") + text_encoder = CLIPTextModel.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + + self.noise_scheduler = FluxKontextModel.get_train_scheduler() + + self.print_and_status_update("Making pipe") + + pipe: FluxKontextPipeline = FluxKontextPipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = FluxKontextModel.get_train_scheduler() + + pipeline: FluxKontextPipeline = FluxKontextPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer) + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: FluxKontextPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + if gen_config.ctrl_img is None: + raise ValueError( + "Control image is required for Flux Kontext model generation." + ) + else: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + img = pipeline( + image=control_img, + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + guidance_embedding_scale: float, + bypass_guidance_embedding: bool, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + # if we have a control on the channel dimension, put it on the batch for packing + has_control = False + if latent_model_input.shape[1] == 32: + # chunk it and stack it on batch dimension + # dont update batch size for img_its + lat, control = torch.chunk(latent_model_input, 2, dim=1) + latent_model_input = torch.cat([lat, control], dim=0) + has_control = True + + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", + b=bs).to(self.device_torch) + + # handle control image ids + if has_control: + ctrl_ids = img_ids.clone() + ctrl_ids[..., 0] = 1 + img_ids = torch.cat([img_ids, ctrl_ids], dim=1) + + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + # # handle guidance + if self.unet_unwrapped.config.guidance_embeds: + if isinstance(guidance_embedding_scale, list): + guidance = torch.tensor( + guidance_embedding_scale, device=self.device_torch) + else: + guidance = torch.tensor( + [guidance_embedding_scale], device=self.device_torch) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + if bypass_guidance_embedding: + bypass_flux_guidance(self.unet) + + cast_dtype = self.unet.dtype + # changes from orig implementation + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + latent_size = latent_model_input_packed.shape[1] + # move the kontext channels. We have them on batch dimension to here, but need to put them on the latent dimension + if has_control: + latent, control = torch.chunk(latent_model_input_packed, 2, dim=0) + latent_model_input_packed = torch.cat( + [latent, control], dim=1 + ) + latent_size = latent.shape[1] + + noise_pred = self.unet( + hidden_states=latent_model_input_packed.to( + self.device_torch, cast_dtype), + timestep=timestep / 1000, + encoder_hidden_states=text_embeddings.text_embeds.to( + self.device_torch, cast_dtype), + pooled_projections=text_embeddings.pooled_embeds.to( + self.device_torch, cast_dtype), + txt_ids=txt_ids, + img_ids=img_ids, + guidance=guidance, + return_dict=False, + **kwargs, + )[0] + + # remove kontext image conditioning + noise_pred = noise_pred[:, :latent_size] + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=self.vae.config.latent_channels + ) + + if bypass_guidance_embedding: + restore_flux_guidance(self.unet) + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( + self.tokenizer, + self.text_encoder, + prompt, + max_length=512, + ) + pe = PromptEmbeds( + prompt_embeds + ) + pe.pooled_embeds = pooled_prompt_embeds + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + # return from a weight if it has grad + return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: FluxTransformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def condition_noisy_latents(self, latents: torch.Tensor, batch:'DataLoaderBatchDTO'): + with torch.no_grad(): + control_tensor = batch.control_tensor + if control_tensor is not None: + # we are not packed here, so we just need to pass them so we can pack them later + control_tensor = control_tensor * 2 - 1 + control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype) + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]: + control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear') + + control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype) + latents = torch.cat((latents, control_latent), dim=1) + + return latents.detach() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 96ddce8e..65cc41c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b +git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76 transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index aeca126e..f05d0a6e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -757,6 +757,8 @@ class DatasetConfig: self.flip_y: bool = kwargs.get('flip_y', False) self.augments: List[str] = kwargs.get('augments', []) self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc + if self.control_path == '': + self.control_path = None # inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will # be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None) diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 0b8b3e47..4949f94a 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -101,10 +101,14 @@ export default function SimpleJob({ setJobConfig(value, 'config.process[0].model.arch'); // update controls for datasets + const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; const controls = newArch?.controls ?? []; const datasets = jobConfig.config.process[0].datasets.map(dataset => { const newDataset = objectCopy(dataset); newDataset.controls = controls; + if (!hasControlPath) { + newDataset.control_path = null; // reset control path if not applicable + } return newDataset; }); setJobConfig(datasets, 'config.process[0].datasets'); @@ -412,6 +416,17 @@ export default function SimpleJob({ onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} options={datasetOptions} /> + {modelArch?.additionalSections?.includes('datasets.control_path') && ( + + setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + )} + { + modelArch?.additionalSections?.includes('sample.ctrl_img') && ( +
+

+ Control Images +

+ To use control images on samples, add --ctrl_img to the prompts below. +
+ Example: make this a cartoon --ctrl_img /path/to/image.png +
+ ) + } {jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index dd28461b..997859d3 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -2,6 +2,7 @@ import { JobConfig, DatasetConfig } from '@/types'; export const defaultDatasetConfig: DatasetConfig = { folder_path: '/path/to/images/folder', + control_path: null, mask_path: null, mask_min_value: 0.1, default_caption: '', diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 2f4e235c..2fc51d28 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -1,5 +1,8 @@ type Control = 'depth' | 'line' | 'pose' | 'inpaint'; +type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; +type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' + export interface ModelArch { name: string; label: string; @@ -7,11 +10,11 @@ export interface ModelArch { isVideoModel?: boolean; defaults?: { [key: string]: any }; disableSections?: DisableableSections[]; + additionalSections?: AdditionalSections[]; } const defaultNameOrPath = ''; -type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; export const modelArchs: ModelArch[] = [ { @@ -27,6 +30,21 @@ export const modelArchs: ModelArch[] = [ }, disableSections: ['network.conv'], }, + { + name: 'flux_kontext', + label: 'FLUX.1-Kontext-dev', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img'], + }, { name: 'flex1', label: 'Flex.1', diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index 8450af3f..f069496c 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -48,6 +48,15 @@ const docs: { [key: string]: ConfigDoc } = { ), }, + 'datasets.control_path': { + title: 'Control Dataset', + description: ( + <> + The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs. + These images are fed as control/input images during training. + + ), + }, }; export const getDoc = (key: string | null | undefined): ConfigDoc | null => { diff --git a/ui/src/types.ts b/ui/src/types.ts index 39431e31..2a8397eb 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -83,6 +83,7 @@ export interface DatasetConfig { cache_latents_to_disk?: boolean; resolution: number[]; controls: string[]; + control_path: string | null; } export interface EMAConfig { diff --git a/version.py b/version.py index 60504c97..85b16574 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.3.2" \ No newline at end of file +VERSION = "0.3.3" \ No newline at end of file