mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-30 18:39:50 +00:00

Added support for FLUX.1-Kontext-dev
config/examples/train_lora_flux_kontext_24gb.yaml (new file, 106 lines)
@@ -0,0 +1,106 @@
---
job: extension
config:
  # this name will be used for the output folder and file names
  name: "my_first_flux_kontext_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
      # trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16 # precision to save
        save_every: 250 # save every this many steps
        max_step_saves_to_keep: 4 # how many intermittent saves to keep
        push_to_hub: false # change this to true to push your trained model to Hugging Face
        # You can either set up a HF_TOKEN env variable or you'll be prompted to log in
        # hf_repo_id: your-username/your-model-slug
        # hf_private: true # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape backslashes with another backslash, e.g.
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          # control_path points to the input images for Kontext in a paired dataset. These are the source images you want to change.
          # You can comment this out and only use normal images if you don't have a paired dataset.
          # Control images need to match the filenames in the folder path, but live in
          # a different folder. They do not need captions.
          control_path: "/path/to/control/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
          shuffle_tokens: false # shuffle caption order, split by commas
          cache_latents_to_disk: true # leave this true unless you know what you're doing
          # Kontext runs images in at 2x the latent size. It may OOM at 1024 resolution with 24GB vram.
          resolution: [ 512, 768 ] # flux enjoys multiple resolutions
          # resolution: [ 512, 768, 1024 ]
      train:
        batch_size: 1
        steps: 3000 # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false # probably won't work with flux
        gradient_checkpointing: true # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch" # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        timestep_type: "weighted" # sigmoid, linear, or weighted
        # uncomment this to skip the pre-training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true

        # ema will smooth out learning, but could slow it down
        # ema_config:
        #   use_ema: true
        #   ema_decay: 0.99

        # you will probably need this if your gpu supports it for flux; other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path. This model is gated:
        # visit https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev to accept the terms and conditions,
        # and then you can use this model.
        name_or_path: "black-forest-labs/FLUX.1-Kontext-dev"
        arch: "flux_kontext"
        quantize: true # run 8bit mixed precision
        # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
      sample:
        sampler: "flowmatch" # must match train.noise_scheduler
        sample_every: 250 # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
          # the --ctrl_img path is the image loaded to apply the kontext editing to
          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
          - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
          - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
          - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
          - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
          - "make the person smile --ctrl_img /path/to/control/folder/person1.jpg"
          - "give the person an afro --ctrl_img /path/to/control/folder/person1.jpg"
          - "turn this image into a cartoon --ctrl_img /path/to/control/folder/person1.jpg"
          - "put this person in an action film --ctrl_img /path/to/control/folder/person1.jpg"
          - "make this person a rapper in a rap music video --ctrl_img /path/to/control/folder/person1.jpg"
        neg: "" # not used on flux
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 20
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
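The comments above spell out the paired-dataset contract: every training image in folder_path needs a same-named control image in control_path and a .txt caption next to it. Below is a minimal sketch (not part of the commit) for sanity-checking that layout before training; the paths are the placeholder paths from the YAML, and the exact matching rules are assumed from the comments.

    # Sketch: verify a paired Kontext dataset matches the layout described above.
    from pathlib import Path

    IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # formats the config notes as supported

    def check_paired_dataset(folder_path: str, control_path: str) -> None:
        folder, control = Path(folder_path), Path(control_path)
        for img in sorted(p for p in folder.iterdir() if p.suffix.lower() in IMAGE_EXTS):
            # captions are .txt files with the same name as the training image
            if not img.with_suffix(".txt").exists():
                print(f"missing caption: {img.with_suffix('.txt').name}")
            # the control image shares the training image's filename,
            # but lives in the control folder and needs no caption
            if not (control / img.name).exists():
                print(f"missing control image: {img.name}")

    check_paired_dataset("/path/to/images/folder", "/path/to/control/folder")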
@@ -2,8 +2,9 @@ from .chroma import ChromaModel
 from .hidream import HidreamModel
 from .f_light import FLiteModel
 from .omnigen2 import OmniGen2Model
+from .flux_kontext import FluxKontextModel

 AI_TOOLKIT_MODELS = [
     # put a list of models here
-    ChromaModel, HidreamModel, FLiteModel, OmniGen2Model
+    ChromaModel, HidreamModel, FLiteModel, OmniGen2Model, FluxKontextModel
 ]
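The new entry registers FluxKontextModel alongside the other model classes. A minimal sketch of how a registry like this can be resolved from the arch string used in the YAML config ("flux_kontext"); the helper below is illustrative only, not the toolkit's actual loader.

    # Illustrative: resolve a model class from AI_TOOLKIT_MODELS by its `arch`
    # attribute, the same string set in the YAML (`arch: "flux_kontext"`).
    def resolve_arch(arch: str, models: list) -> type:
        for cls in models:
            if getattr(cls, "arch", None) == arch:
                return cls
        raise ValueError(f"unknown arch: {arch}")

    # e.g. resolve_arch("flux_kontext", AI_TOOLKIT_MODELS) -> FluxKontextModel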
@@ -0,0 +1 @@
from .flux_kontext import FluxKontextModel
@@ -0,0 +1,400 @@
import os
from typing import TYPE_CHECKING, List

import torch
import torchvision
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from diffusers import FluxTransformer2DModel, AutoencoderKL, FluxKontextPipeline
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.mask import generate_random_mask, random_dialate_mask
from toolkit.util.quantize import quantize, get_qtype
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from einops import rearrange, repeat
import random
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "max_image_seq_len": 4096,
    "max_shift": 1.15,
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": True
}

class FluxKontextModel(BaseModel):
    arch = "flux_kontext"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(
            device,
            model_config,
            dtype,
            custom_pipeline,
            noise_scheduler,
            **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['FluxTransformer2DModel']

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 16

    def load_model(self):
        dtype = self.torch_dtype
        self.print_and_status_update("Loading Flux Kontext model")
        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path
        # this is the original path put in the model directory
        # it is here because for finetuning we usually only save the transformer,
        # so we need this for the VAE, text encoders, etc.
        base_model_path = self.model_config.extras_name_or_path

        transformer_path = model_path
        transformer_subfolder = 'transformer'
        if os.path.exists(transformer_path):
            transformer_subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        self.print_and_status_update("Loading transformer")
        transformer = FluxTransformer2DModel.from_pretrained(
            transformer_path,
            subfolder=transformer_subfolder,
            torch_dtype=dtype,
            revision="7610c9af"
        )
        transformer.to(self.quantize_device, dtype=dtype)

        if self.model_config.quantize:
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = get_qtype(self.model_config.qtype)
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type,
                     **self.model_config.quantize_kwargs)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        self.print_and_status_update("Loading T5")
        tokenizer_2 = T5TokenizerFast.from_pretrained(
            base_model_path, subfolder="tokenizer_2", torch_dtype=dtype
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            base_model_path, subfolder="text_encoder_2", torch_dtype=dtype
        )
        text_encoder_2.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing T5")
            quantize(text_encoder_2, weights=get_qtype(
                self.model_config.qtype))
            freeze(text_encoder_2)
            flush()

        self.print_and_status_update("Loading CLIP")
        text_encoder = CLIPTextModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype)
        tokenizer = CLIPTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder.to(self.device_torch, dtype=dtype)

        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)

        self.noise_scheduler = FluxKontextModel.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: FluxKontextPipeline = FluxKontextPipeline(
            scheduler=self.noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=None,
            tokenizer_2=tokenizer_2,
            vae=vae,
            transformer=None,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.text_encoder_2 = text_encoder_2
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
        tokenizer = [pipe.tokenizer, pipe.tokenizer_2]

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        text_encoder[1].to(self.device_torch)
        text_encoder[1].requires_grad_(False)
        text_encoder[1].eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        scheduler = FluxKontextModel.get_train_scheduler()

        pipeline: FluxKontextPipeline = FluxKontextPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            text_encoder_2=unwrap_model(self.text_encoder[1]),
            tokenizer_2=self.tokenizer[1],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer)
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: FluxKontextPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        if gen_config.ctrl_img is None:
            raise ValueError(
                "Control image is required for Flux Kontext model generation."
            )
        else:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
        img = pipeline(
            image=control_img,
            prompt_embeds=conditional_embeds.text_embeds,
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        guidance_embedding_scale: float,
        bypass_guidance_embedding: bool,
        **kwargs
    ):
        with torch.no_grad():
            bs, c, h, w = latent_model_input.shape
            # if we have a control on the channel dimension, put it on the batch dimension for packing
            has_control = False
            if latent_model_input.shape[1] == 32:
                # chunk it and stack it on the batch dimension
                # don't update the batch size used for the img ids
                lat, control = torch.chunk(latent_model_input, 2, dim=1)
                latent_model_input = torch.cat([lat, control], dim=0)
                has_control = True

            latent_model_input_packed = rearrange(
                latent_model_input,
                "b c (h ph) (w pw) -> b (h w) (c ph pw)",
                ph=2,
                pw=2
            )

            img_ids = torch.zeros(h // 2, w // 2, 3)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c",
                             b=bs).to(self.device_torch)

            # handle control image ids
            if has_control:
                ctrl_ids = img_ids.clone()
                ctrl_ids[..., 0] = 1
                img_ids = torch.cat([img_ids, ctrl_ids], dim=1)

            txt_ids = torch.zeros(
                bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

            # handle guidance
            if self.unet_unwrapped.config.guidance_embeds:
                if isinstance(guidance_embedding_scale, list):
                    guidance = torch.tensor(
                        guidance_embedding_scale, device=self.device_torch)
                else:
                    guidance = torch.tensor(
                        [guidance_embedding_scale], device=self.device_torch)
                    guidance = guidance.expand(latent_model_input.shape[0])
            else:
                guidance = None

        if bypass_guidance_embedding:
            bypass_flux_guidance(self.unet)

        cast_dtype = self.unet.dtype
        # changes from orig implementation
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            img_ids = img_ids[0]

        latent_size = latent_model_input_packed.shape[1]
        # move the kontext channels: they were stacked on the batch dimension above,
        # but need to go on the sequence (latent) dimension for the transformer
        if has_control:
            latent, control = torch.chunk(latent_model_input_packed, 2, dim=0)
            latent_model_input_packed = torch.cat(
                [latent, control], dim=1
            )
            latent_size = latent.shape[1]

        noise_pred = self.unet(
            hidden_states=latent_model_input_packed.to(
                self.device_torch, cast_dtype),
            timestep=timestep / 1000,
            encoder_hidden_states=text_embeddings.text_embeds.to(
                self.device_torch, cast_dtype),
            pooled_projections=text_embeddings.pooled_embeds.to(
                self.device_torch, cast_dtype),
            txt_ids=txt_ids,
            img_ids=img_ids,
            guidance=guidance,
            return_dict=False,
            **kwargs,
        )[0]

        # remove kontext image conditioning
        noise_pred = noise_pred[:, :latent_size]

        if isinstance(noise_pred, QTensor):
            noise_pred = noise_pred.dequantize()

        noise_pred = rearrange(
            noise_pred,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=latent_model_input.shape[2] // 2,
            w=latent_model_input.shape[3] // 2,
            ph=2,
            pw=2,
            c=self.vae.config.latent_channels
        )

        if bypass_guidance_embedding:
            restore_flux_guidance(self.unet)

        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)
        prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
            self.tokenizer,
            self.text_encoder,
            prompt,
            max_length=512,
        )
        pe = PromptEmbeds(
            prompt_embeds
        )
        pe.pooled_embeds = pooled_prompt_embeds
        return pe

    def get_model_has_grad(self):
        # return from a weight if it has grad
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        # return from a weight if it has grad
        return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: FluxTransformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        return (noise - batch.latents).detach()

    def condition_noisy_latents(self, latents: torch.Tensor, batch: 'DataLoaderBatchDTO'):
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                # we are not packed here, so we just need to pass the control latents along so we can pack them later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(self.vae_device_torch, dtype=self.torch_dtype)

                # if it is not the size of batch.tensor (bs, ch, h, w), then we need to resize it
                if control_tensor.shape[2] != batch.tensor.shape[2] or control_tensor.shape[3] != batch.tensor.shape[3]:
                    control_tensor = F.interpolate(control_tensor, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')

                control_latent = self.encode_images(control_tensor).to(latents.device, latents.dtype)
                latents = torch.cat((latents, control_latent), dim=1)

        return latents.detach()
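A small, self-contained sketch of the shape bookkeeping that condition_noisy_latents and get_noise_prediction rely on: the control latent is concatenated on the channel dimension (16 + 16 = 32 channels), split back out onto the batch dimension for 2x2 patch packing, moved to the sequence dimension for the transformer, and the prediction is finally cropped back to the original sequence length. The concrete shapes below are illustrative assumptions (16 latent channels, a 64x64 latent as for a 512px image with an 8x VAE downsample), not toolkit code.

    # Illustrative shape walk-through of the Kontext latent packing.
    import torch
    from einops import rearrange

    bs, c, h, w = 2, 16, 64, 64
    latents = torch.randn(bs, c, h, w)
    control_latent = torch.randn(bs, c, h, w)

    # condition_noisy_latents: concat control on the channel dim -> (bs, 32, h, w)
    x = torch.cat((latents, control_latent), dim=1)

    # get_noise_prediction: split channels back out and stack on the batch dim
    lat, ctrl = torch.chunk(x, 2, dim=1)            # each (bs, 16, h, w)
    x = torch.cat([lat, ctrl], dim=0)               # (2*bs, 16, h, w)

    # pack 2x2 latent patches into a token sequence
    packed = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    latent_tokens, ctrl_tokens = torch.chunk(packed, 2, dim=0)
    seq = torch.cat([latent_tokens, ctrl_tokens], dim=1)   # control tokens appended on the sequence dim

    # the transformer output is cropped back to the first latent_size tokens,
    # i.e. only the non-control part is kept as the noise prediction
    latent_size = latent_tokens.shape[1]
    print(seq.shape, latent_size)                   # torch.Size([2, 2048, 64]) 1024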
@@ -1,7 +1,7 @@
torchao==0.10.0
safetensors
git+https://github.com/jaretburkett/easy_dwpose.git
-git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
+git+https://github.com/huggingface/diffusers@00f95b9755718aabb65456e791b8408526ae6e76
transformers==4.52.4
lycoris-lora==1.8.3
flatten_json
@@ -757,6 +757,8 @@ class DatasetConfig:
        self.flip_y: bool = kwargs.get('flip_y', False)
        self.augments: List[str] = kwargs.get('augments', [])
        self.control_path: Union[str, List[str]] = kwargs.get('control_path', None)  # depth maps, etc.
+        if self.control_path == '':
+            self.control_path = None
        # inpaint images should be webp/png images with an alpha channel. The alpha 0 (invisible) section will
        # be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored
        self.inpaint_path: Union[str, List[str]] = kwargs.get('inpaint_path', None)
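The inpaint comment above describes the expected data format: an RGBA webp/png where transparent pixels (alpha 0) mark the region to be inpainted and opaque pixels are left alone. A minimal sketch of writing such an image, assuming a boolean mask as input; numpy and PIL here are only for the illustration, not requirements implied by the diff.

    # Illustrative: write an RGBA inpaint image where alpha 0 marks the region
    # to be inpainted, matching the format described in the comment above.
    import numpy as np
    from PIL import Image

    def save_inpaint_image(rgb: np.ndarray, inpaint_region: np.ndarray, path: str) -> None:
        # rgb: (H, W, 3) uint8; inpaint_region: (H, W) bool, True where inpainting should happen
        alpha = np.where(inpaint_region, 0, 255).astype(np.uint8)   # invisible = inpaint
        rgba = np.dstack([rgb, alpha])
        Image.fromarray(rgba, mode="RGBA").save(path)               # e.g. a .png or .webp path

    # example usage with placeholder data
    save_inpaint_image(
        np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8),
        np.zeros((256, 256), dtype=bool),
        "inpaint_example.png",
    )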
@@ -101,10 +101,14 @@ export default function SimpleJob({
      setJobConfig(value, 'config.process[0].model.arch');

      // update controls for datasets
      const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
      const controls = newArch?.controls ?? [];
      const datasets = jobConfig.config.process[0].datasets.map(dataset => {
        const newDataset = objectCopy(dataset);
        newDataset.controls = controls;
        if (!hasControlPath) {
          newDataset.control_path = null; // reset control path if not applicable
        }
        return newDataset;
      });
      setJobConfig(datasets, 'config.process[0].datasets');
@@ -412,6 +416,17 @@ export default function SimpleJob({
                  onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
                  options={datasetOptions}
                />
                {modelArch?.additionalSections?.includes('datasets.control_path') && (
                  <SelectInput
                    label="Control Dataset"
                    docKey="datasets.control_path"
                    value={dataset.control_path ?? ''}
                    onChange={value =>
                      setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
                    }
                    options={[{ value: '', label: <> </> }, ...datasetOptions]}
                  />
                )}
                <NumberInput
                  label="LoRA Weight"
                  value={dataset.network_weight}
@@ -604,6 +619,18 @@ export default function SimpleJob({
              )}
            </div>
            <FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
              {
                modelArch?.additionalSections?.includes('sample.ctrl_img') && (
                  <div className='text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg'>
                    <p className='font-semibold mb-1'>
                      Control Images
                    </p>
                    To use control images on samples, add --ctrl_img to the prompts below.
                    <br />
                    Example: <code className='bg-yellow-900 p-1'>make this a cartoon --ctrl_img /path/to/image.png</code>
                  </div>
                )
              }
              {jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
                <div key={i} className="flex items-center space-x-2">
                  <div className="flex-1">
@@ -2,6 +2,7 @@ import { JobConfig, DatasetConfig } from '@/types';
export const defaultDatasetConfig: DatasetConfig = {
  folder_path: '/path/to/images/folder',
+  control_path: null,
  mask_path: null,
  mask_min_value: 0.1,
  default_caption: '',
@@ -1,5 +1,8 @@
type Control = 'depth' | 'line' | 'pose' | 'inpaint';

type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img';

export interface ModelArch {
  name: string;
  label: string;
@@ -7,11 +10,11 @@ export interface ModelArch {
  isVideoModel?: boolean;
  defaults?: { [key: string]: any };
  disableSections?: DisableableSections[];
+  additionalSections?: AdditionalSections[];
}

const defaultNameOrPath = '';

-type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';

export const modelArchs: ModelArch[] = [
  {
@@ -27,6 +30,21 @@ export const modelArchs: ModelArch[] = [
    },
    disableSections: ['network.conv'],
  },
  {
    name: 'flux_kontext',
    label: 'FLUX.1-Kontext-dev',
    defaults: {
      // default updates when [selected, unselected] in the UI
      'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
      'config.process[0].model.quantize': [true, false],
      'config.process[0].model.quantize_te': [true, false],
      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
      'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
    },
    disableSections: ['network.conv'],
    additionalSections: ['datasets.control_path', 'sample.ctrl_img'],
  },
  {
    name: 'flex1',
    label: 'Flex.1',
@@ -48,6 +48,15 @@ const docs: { [key: string]: ConfigDoc } = {
      </>
    ),
  },
  'datasets.control_path': {
    title: 'Control Dataset',
    description: (
      <>
        The control dataset needs to have files that match the filenames of your training dataset; they should be matching file pairs.
        These images are fed as the control/input images during training.
      </>
    ),
  },
};

export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
@@ -83,6 +83,7 @@ export interface DatasetConfig {
  cache_latents_to_disk?: boolean;
  resolution: number[];
  controls: string[];
+  control_path: string | null;
}

export interface EMAConfig {
@@ -1 +1 @@
-VERSION = "0.3.2"
+VERSION = "0.3.3"