mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-10 15:39:57 +00:00
Merge pull request #434 from ostris/qwen_image_edit_plus
Add full support for Qwen-Image-Edit-2509
This commit is contained in:
105  config/examples/train_lora_qwen_image_edit_2509_32gb.yaml  (new file)
@@ -0,0 +1,105 @@
---
job: extension
config:
  # this name will be used for the output folder and file names
  name: "my_first_qwen_image_edit_2509_lora_v1"
  process:
    - type: 'diffusion_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
      # performance_log_every: 1000
      device: cuda:0
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16 # precision to save
        save_every: 250 # save every this many steps
        max_step_saves_to_keep: 4 # how many intermittent saves to keep
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          # up to 3 control image folders; file names must match the target file names, but aspect ratio/size can differ
          control_path:
            - "/path/to/control/images/folder1"
            - "/path/to/control/images/folder2"
            - "/path/to/control/images/folder3"
          caption_ext: "txt"
          # default_caption: "a person" # when caching text embeddings, this is cached for images that have no caption
          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
          resolution: [ 512, 768, 1024 ] # qwen image trains well at multiple resolutions
          # a trigger word that can be cached with the text embeddings
          # trigger_word: "optional trigger word"
      train:
        batch_size: 1
        # caching text embeddings is required for 32GB
        cache_text_embeddings: true
        # unload_text_encoder: true

        steps: 3000 # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation: 1
        timestep_type: "weighted"
        train_unet: true
        train_text_encoder: false # probably won't work with qwen image
        gradient_checkpointing: true # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch" # for training only
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre training sample
        # skip_first_sample: true
        # uncomment to completely disable sampling
        # disable_sampling: true
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "Qwen/Qwen-Image-Edit-2509"
        arch: "qwen_image_edit_plus"
        quantize: true
        # to use the ARA, use the | pipe to point to a HF path, or a local path if you have one
        # 3-bit is required for 32GB
        qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors"
        quantize_te: true
        qtype_te: "qfloat8"
        low_vram: true
      sample:
        sampler: "flowmatch" # must match train.noise_scheduler
        sample_every: 250 # sample every this many steps
        width: 1024
        height: 1024
        # you can provide up to 3 control images here
        samples:
          - prompt: "Do whatever with Image1 and Image2"
            ctrl_img_1: "/path/to/image1.png"
            ctrl_img_2: "/path/to/image2.png"
            # ctrl_img_3: "/path/to/image3.png"
          - prompt: "Do whatever with Image1 and Image2"
            ctrl_img_1: "/path/to/image1.png"
            ctrl_img_2: "/path/to/image2.png"
            # ctrl_img_3: "/path/to/image3.png"
          - prompt: "Do whatever with Image1 and Image2"
            ctrl_img_1: "/path/to/image1.png"
            ctrl_img_2: "/path/to/image2.png"
            # ctrl_img_3: "/path/to/image3.png"
          - prompt: "Do whatever with Image1 and Image2"
            ctrl_img_1: "/path/to/image1.png"
            ctrl_img_2: "/path/to/image2.png"
            # ctrl_img_3: "/path/to/image3.png"
          - prompt: "Do whatever with Image1 and Image2"
            ctrl_img_1: "/path/to/image1.png"
            ctrl_img_2: "/path/to/image2.png"
            # ctrl_img_3: "/path/to/image3.png"
        neg: ""
        seed: 42
        walk_seed: true
        guidance_scale: 3
        sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'
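Since each control folder must mirror the file names of the target dataset, a quick pre-flight check can catch mismatches before a long run. This is a minimal standalone sketch (not part of ai-toolkit); the folder paths are the placeholders from the example config above.

# Hypothetical helper: verify every training image has a match in each control folder.
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # extensions the trainer currently supports

def stems(folder: str) -> set[str]:
    return {p.stem for p in Path(folder).iterdir() if p.suffix.lower() in IMAGE_EXTS}

target = stems("/path/to/images/folder")
for ctrl in ["/path/to/control/images/folder1",
             "/path/to/control/images/folder2",
             "/path/to/control/images/folder3"]:
    missing = target - stems(ctrl)
    if missing:
        print(f"{ctrl} is missing control images for: {sorted(missing)[:5]} ...")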
@@ -4,7 +4,7 @@ from .f_light import FLiteModel
 from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
 from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
-from .qwen_image import QwenImageModel, QwenImageEditModel
+from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel
 
 AI_TOOLKIT_MODELS = [
     # put a list of models here
@@ -20,4 +20,5 @@ AI_TOOLKIT_MODELS = [
     Wan2214bModel,
     QwenImageModel,
     QwenImageEditModel,
+    QwenImageEditPlusModel,
 ]

@@ -1,2 +1,3 @@
 from .qwen_image import QwenImageModel
 from .qwen_image_edit import QwenImageEditModel
+from .qwen_image_edit_plus import QwenImageEditPlusModel
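Registration is just membership in AI_TOOLKIT_MODELS; the `arch` string in a config (here "qwen_image_edit_plus") selects the class whose `arch` attribute matches. A hypothetical lookup to illustrate the idea; the toolkit's own resolver and the import path shown are assumptions, not the actual API.

# Hypothetical arch -> model class lookup; mirrors how arch strings select a registered model.
from extensions_built_in.diffusion_models import AI_TOOLKIT_MODELS  # assumed import path

def get_model_class(arch: str):
    for cls in AI_TOOLKIT_MODELS:
        if getattr(cls, "arch", None) == arch:
            return cls
    raise ValueError(f"Unknown model arch: {arch}")

print(get_model_class("qwen_image_edit_plus").__name__)  # QwenImageEditPlusModel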
@@ -0,0 +1,309 @@
import math
import torch
from .qwen_image import QwenImageModel
import os
from typing import TYPE_CHECKING, List, Optional
import yaml
from toolkit import train_tools
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from PIL import Image
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import get_accelerator, unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model
import torch.nn.functional as F

from diffusers import (
    QwenImageTransformer2DModel,
    AutoencoderKLQwenImage,
)
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from tqdm import tqdm


if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

try:
    from diffusers import QwenImageEditPlusPipeline
    from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE
except ImportError:
    raise ImportError(
        "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'"
    )


class QwenImageEditPlusModel(QwenImageModel):
    arch = "qwen_image_edit_plus"
    _qwen_image_keep_visual = True
    _qwen_pipeline = QwenImageEditPlusPipeline

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["QwenImageTransformer2DModel"]

        # set true for models that encode control image into text embeddings
        self.encode_control_in_text_embeddings = True
        # control images will come in as a list for encoding some things if true
        self.has_multiple_control_images = True
        # do not resize control images
        self.use_raw_control_images = True

    def load_model(self):
        super().load_model()

    def get_generation_pipeline(self):
        scheduler = QwenImageModel.get_train_scheduler()

        pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline(
            scheduler=scheduler,
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            processor=self.processor,
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer),
        )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: QwenImageEditPlusPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        self.model.to(self.device_torch, dtype=self.torch_dtype)
        sc = self.get_bucket_divisibility()
        gen_config.width = int(gen_config.width // sc * sc)
        gen_config.height = int(gen_config.height // sc * sc)

        control_img_list = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        elif gen_config.ctrl_img_1 is not None:
            control_img = Image.open(gen_config.ctrl_img_1)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)

        if gen_config.ctrl_img_2 is not None:
            control_img = Image.open(gen_config.ctrl_img_2)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        if gen_config.ctrl_img_3 is not None:
            control_img = Image.open(gen_config.ctrl_img_3)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)

        # flush for low vram if we are doing that
        # flush_between_steps = self.model_config.low_vram
        flush_between_steps = False

        # Fix a bug in diffusers/torch
        def callback_on_step_end(pipe, i, t, callback_kwargs):
            if flush_between_steps:
                flush()
            latents = callback_kwargs["latents"]

            return {"latents": latents}

        img = pipeline(
            image=control_img_list,
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_embeds_mask=conditional_embeds.attention_mask.to(
                self.device_torch, dtype=torch.int64
            ),
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(
                self.device_torch, dtype=torch.int64
            ),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            true_cfg_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
            **extra,
        ).images[0]
        return img

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        # we get the control image from the batch
        return latents.detach()

    def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds:
        # todo handle not caching text encoder
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)

        if control_images is not None and len(control_images) > 0:
            for i in range(len(control_images)):
                # control images are 0 - 1 scale, shape (bs, ch, height, width)
                ratio = control_images[i].shape[2] / control_images[i].shape[3]
                width = math.sqrt(CONDITION_IMAGE_SIZE * ratio)
                height = width / ratio

                width = round(width / 32) * 32
                height = round(height / 32) * 32

                control_images[i] = F.interpolate(
                    control_images[i], size=(height, width), mode="bilinear"
                )

        prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
            prompt,
            image=control_images,
            device=self.device_torch,
            num_images_per_prompt=1,
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_embeds_mask
        return pe

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: "DataLoaderBatchDTO" = None,
        **kwargs,
    ):
        with torch.no_grad():
            batch_size, num_channels_latents, height, width = latent_model_input.shape

            # pack image tokens
            latent_model_input = latent_model_input.view(
                batch_size, num_channels_latents, height // 2, 2, width // 2, 2
            )
            latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
            latent_model_input = latent_model_input.reshape(
                batch_size, (height // 2) * (width // 2), num_channels_latents * 4
            )

            raw_packed_latents = latent_model_input

            img_h2, img_w2 = height // 2, width // 2

            img_shapes = [
                [(1, img_h2, img_w2)]
            ] * batch_size

            # pack controls
            if batch is None:
                raise ValueError("Batch is required for QwenImageEditPlusModel")

            # split the latents into batch items so we can concat the controls
            packed_latents_list = torch.chunk(latent_model_input, batch_size, dim=0)
            packed_latents_with_controls_list = []

            if batch.control_tensor_list is not None:
                if len(batch.control_tensor_list) != batch_size:
                    raise ValueError("Control tensor list length does not match batch size")
                b = 0
                for control_tensor_list in batch.control_tensor_list:
                    # control tensor list is a list of tensors for this batch item
                    controls = []
                    # pack control
                    for control_img in control_tensor_list:
                        # control images are 0 - 1 scale, shape (1, ch, height, width)
                        control_img = control_img.to(self.device_torch, dtype=self.torch_dtype)
                        # if it is only 3 dim, add batch dim
                        if len(control_img.shape) == 3:
                            control_img = control_img.unsqueeze(0)
                        ratio = control_img.shape[2] / control_img.shape[3]
                        c_width = math.sqrt(VAE_IMAGE_SIZE * ratio)
                        c_height = c_width / ratio

                        c_width = round(c_width / 32) * 32
                        c_height = round(c_height / 32) * 32

                        control_img = F.interpolate(
                            control_img, size=(c_height, c_width), mode="bilinear"
                        )

                        control_latent = self.encode_images(
                            control_img,
                            device=self.device_torch,
                            dtype=self.torch_dtype,
                        )

                        clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape

                        control = control_latent.view(
                            1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2
                        )
                        control = control.permute(0, 2, 4, 1, 3, 5)
                        control = control.reshape(
                            1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4
                        )

                        img_shapes[b].append((1, cl_height // 2, cl_width // 2))
                        controls.append(control)

                    # stack controls on dim 1
                    control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype)
                    # concat with latents
                    packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1)

                    packed_latents_with_controls_list.append(packed_latents_with_control)

                    b += 1

                latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0)

            prompt_embeds_mask = text_embeddings.attention_mask.to(
                self.device_torch, dtype=torch.int64
            )
            txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
            enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
            prompt_embeds_mask = text_embeddings.attention_mask.to(
                self.device_torch, dtype=torch.int64
            )

        noise_pred = self.transformer(
            hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
            timestep=timestep / 1000,
            guidance=None,
            encoder_hidden_states=enc_hs,
            encoder_hidden_states_mask=prompt_embeds_mask,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            return_dict=False,
            **kwargs,
        )[0]

        noise_pred = noise_pred[:, : raw_packed_latents.size(1)]

        # unpack
        noise_pred = noise_pred.view(
            batch_size, height // 2, width // 2, num_channels_latents, 2, 2
        )
        noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
        noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
        return noise_pred
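The pack/unpack logic in get_noise_prediction above turns a (B, C, H, W) latent into a sequence of (H/2)*(W/2) tokens of size 4*C and back. A small self-contained sketch in plain PyTorch (not tied to ai-toolkit) shows that the permute/reshape pair is an exact round trip, which is why the trainer can slice off the appended control tokens with noise_pred[:, :raw_packed_latents.size(1)] and still unpack cleanly.

# Standalone check of the 2x2 patch packing used in get_noise_prediction.
import torch

B, C, H, W = 2, 16, 64, 64          # example latent shape; H and W must be divisible by 2
latents = torch.randn(B, C, H, W)

# pack: (B, C, H, W) -> (B, (H/2)*(W/2), C*4)
packed = latents.view(B, C, H // 2, 2, W // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(B, (H // 2) * (W // 2), C * 4)

# unpack: (B, (H/2)*(W/2), C*4) -> (B, C, H, W)
unpacked = packed.view(B, H // 2, W // 2, C, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(B, C, H, W)

assert torch.equal(latents, unpacked)   # exact round trip
print(packed.shape)                      # torch.Size([2, 1024, 64])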
@@ -129,17 +129,64 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
negative_prompt=sample_item.neg,
|
||||
output_path=output_path,
|
||||
ctrl_img=sample_item.ctrl_img
|
||||
ctrl_img=sample_item.ctrl_img,
|
||||
ctrl_img_1=sample_item.ctrl_img_1,
|
||||
ctrl_img_2=sample_item.ctrl_img_2,
|
||||
ctrl_img_3=sample_item.ctrl_img_3,
|
||||
)
|
||||
|
||||
has_control_images = False
|
||||
if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None:
|
||||
has_control_images = True
|
||||
# see if we need to encode the control images
|
||||
if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
|
||||
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
if self.sd.encode_control_in_text_embeddings and has_control_images:
|
||||
|
||||
ctrl_img_list = []
|
||||
|
||||
if gen_img_config.ctrl_img is not None:
|
||||
ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img)
|
||||
|
||||
if gen_img_config.ctrl_img_1 is not None:
|
||||
ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img_1 = (
|
||||
TF.to_tensor(ctrl_img_1)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img_1)
|
||||
if gen_img_config.ctrl_img_2 is not None:
|
||||
ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img_2 = (
|
||||
TF.to_tensor(ctrl_img_2)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img_2)
|
||||
if gen_img_config.ctrl_img_3 is not None:
|
||||
ctrl_img_3 = Image.open(gen_img_config.ctrl_img_3).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img_3 = (
|
||||
TF.to_tensor(ctrl_img_3)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(ctrl_img_3)
|
||||
|
||||
if self.sd.has_multiple_control_images:
|
||||
ctrl_img = ctrl_img_list
|
||||
else:
|
||||
ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None
|
||||
|
||||
|
||||
positive = self.sd.encode_prompt(
|
||||
gen_img_config.prompt,
|
||||
control_images=ctrl_img
|
||||
@@ -202,6 +249,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.sd.encode_control_in_text_embeddings:
|
||||
# just do a blank image for unconditionals
|
||||
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
if self.sd.has_multiple_control_images:
|
||||
control_image = [control_image]
|
||||
|
||||
kwargs['control_images'] = control_image
|
||||
self.unconditional_embeds = self.sd.encode_prompt(
|
||||
[self.train_config.unconditional_prompt],
|
||||
@@ -272,6 +322,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.sd.encode_control_in_text_embeddings:
|
||||
# just do a blank image for unconditionals
|
||||
control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
if self.sd.has_multiple_control_images:
|
||||
control_image = [control_image]
|
||||
encode_kwargs['control_images'] = control_image
|
||||
self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
|
||||
if self.trigger_word is not None:
|
||||
|
||||
@@ -348,6 +348,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     fps=sample_item.fps,
                     ctrl_img=sample_item.ctrl_img,
                     ctrl_idx=sample_item.ctrl_idx,
+                    ctrl_img_1=sample_item.ctrl_img_1,
+                    ctrl_img_2=sample_item.ctrl_img_2,
+                    ctrl_img_3=sample_item.ctrl_img_3,
                     **extra_args
                 ))
 
@@ -1,7 +1,7 @@
 torchao==0.10.0
 safetensors
 git+https://github.com/jaretburkett/easy_dwpose.git
-git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63
+git+https://github.com/huggingface/diffusers@1448b035859dd57bbb565239dcdd79a025a85422
 transformers==4.52.4
 lycoris-lora==1.8.3
 flatten_json
@@ -56,6 +56,11 @@ class SampleItem:
         self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames)
         self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
         self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
+        # for multi control image models
+        self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img)
+        self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None)
+        self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None)
+
         self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
         # convert to a number if it is a string
         if isinstance(self.network_multiplier, str):
@@ -826,6 +831,21 @@ class DatasetConfig:
         if self.control_path == '':
             self.control_path = None
 
+        # handle multi control inputs from the UI; it is easier to handle them here for a cleaner UI experience
+        control_path_1 = kwargs.get('control_path_1', None)
+        control_path_2 = kwargs.get('control_path_2', None)
+        control_path_3 = kwargs.get('control_path_3', None)
+
+        if any([control_path_1, control_path_2, control_path_3]):
+            control_paths = []
+            if control_path_1:
+                control_paths.append(control_path_1)
+            if control_path_2:
+                control_paths.append(control_path_2)
+            if control_path_3:
+                control_paths.append(control_path_3)
+            self.control_path = control_paths
+
         # color for the transparent region of control images with transparency
         self.control_transparent_color: List[int] = kwargs.get('control_transparent_color', [0, 0, 0])
         # inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will
@@ -966,6 +986,9 @@ class GenerateImageConfig:
             extra_values: List[float] = None,  # extra values to save with prompt file
             logger: Optional[EmptyLogger] = None,
             ctrl_img: Optional[str] = None,  # control image for controlnet
+            ctrl_img_1: Optional[str] = None,  # first control image for multi control model
+            ctrl_img_2: Optional[str] = None,  # second control image for multi control model
+            ctrl_img_3: Optional[str] = None,  # third control image for multi control model
             num_frames: int = 1,
             fps: int = 15,
             ctrl_idx: int = 0
@@ -1002,6 +1025,12 @@
         self.ctrl_img = ctrl_img
         self.ctrl_idx = ctrl_idx
+
+        if ctrl_img_1 is None and ctrl_img is not None:
+            ctrl_img_1 = ctrl_img
+
+        self.ctrl_img_1 = ctrl_img_1
+        self.ctrl_img_2 = ctrl_img_2
+        self.ctrl_img_3 = ctrl_img_3
 
         # prompt string will override any settings above
         self._process_prompt_string()
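The UI sends control_path_1/2/3 as separate keys and DatasetConfig folds whichever are set into a single control_path list, in order. A tiny standalone sketch of that folding; the real class does this inside its constructor, so the helper name here is only illustrative.

# Standalone sketch of how control_path_1..3 from the UI collapse into one list.
def fold_control_paths(kwargs: dict):
    paths = [kwargs.get(f"control_path_{i}") for i in (1, 2, 3)]
    picked = [p for p in paths if p]          # keep only the ones that are set, in order
    return picked if picked else kwargs.get("control_path")

print(fold_control_paths({"control_path_1": "/ctrl/a", "control_path_3": "/ctrl/c"}))
# ['/ctrl/a', '/ctrl/c']
print(fold_control_paths({"control_path": "/ctrl/single"}))
# /ctrl/single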
@@ -144,6 +144,7 @@ class DataLoaderBatchDTO:
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None
|
||||
self.clip_image_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
@@ -160,7 +161,6 @@ class DataLoaderBatchDTO:
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.prompt_embeds: Union[PromptEmbeds, None] = None
|
||||
# if self.file_items[0].control_tensor is not None:
|
||||
# if any have a control tensor, we concatenate them
|
||||
@@ -178,6 +178,16 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
control_tensors.append(x.control_tensor)
|
||||
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
|
||||
|
||||
# handle control tensor list
|
||||
if any([x.control_tensor_list is not None for x in self.file_items]):
|
||||
self.control_tensor_list = []
|
||||
for x in self.file_items:
|
||||
if x.control_tensor_list is not None:
|
||||
self.control_tensor_list.append(x.control_tensor_list)
|
||||
else:
|
||||
raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
|
||||
|
||||
|
||||
self.inpaint_tensor: Union[torch.Tensor, None] = None
|
||||
if any([x.inpaint_tensor is not None for x in self.file_items]):
|
||||
|
||||
@@ -850,6 +850,9 @@ class ControlFileItemDTOMixin:
|
||||
self.has_control_image = False
|
||||
self.control_path: Union[str, List[str], None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.control_tensor_list: Union[List[torch.Tensor], None] = None
|
||||
sd = kwargs.get('sd', None)
|
||||
self.use_raw_control_images = sd is not None and sd.use_raw_control_images
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
self.full_size_control_images = False
|
||||
if dataset_config.control_path is not None:
|
||||
@@ -900,23 +903,14 @@ class ControlFileItemDTOMixin:
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading image: {control_path}")
|
||||
|
||||
|
||||
if not self.full_size_control_images:
|
||||
# we just scale them to 512x512:
|
||||
w, h = img.size
|
||||
img = img.resize((512, 512), Image.BICUBIC)
|
||||
|
||||
else:
|
||||
elif not self.use_raw_control_images:
|
||||
w, h = img.size
|
||||
if w > h and self.scale_to_width < self.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
elif h > w and self.scale_to_height < self.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
|
||||
|
||||
if self.flip_x:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
@@ -950,11 +944,15 @@ class ControlFileItemDTOMixin:
|
||||
self.control_tensor = None
|
||||
elif len(control_tensors) == 1:
|
||||
self.control_tensor = control_tensors[0]
|
||||
elif self.use_raw_control_images:
|
||||
# just send the list of tensors as their shapes wont match
|
||||
self.control_tensor_list = control_tensors
|
||||
else:
|
||||
self.control_tensor = torch.stack(control_tensors, dim=0)
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.control_tensor = None
|
||||
self.control_tensor_list = None
|
||||
|
||||
|
||||
class ClipImageFileItemDTOMixin:
|
||||
@@ -1884,14 +1882,31 @@ class TextEmbeddingCachingMixin:
|
||||
if file_item.encode_control_in_text_embeddings:
|
||||
if file_item.control_path is None:
|
||||
raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model")
|
||||
# load the control image and feed it into the text encoder
|
||||
ctrl_img = Image.open(file_item.control_path).convert("RGB")
|
||||
# convert to 0 to 1 tensor
|
||||
ctrl_img = (
|
||||
TF.to_tensor(ctrl_img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list = []
|
||||
control_path_list = file_item.control_path
|
||||
if not isinstance(file_item.control_path, list):
|
||||
control_path_list = [control_path_list]
|
||||
for i in range(len(control_path_list)):
|
||||
try:
|
||||
img = Image.open(control_path_list[i]).convert("RGB")
|
||||
img = exif_transpose(img)
|
||||
# convert to 0 to 1 tensor
|
||||
img = (
|
||||
TF.to_tensor(img)
|
||||
.unsqueeze(0)
|
||||
.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
|
||||
)
|
||||
ctrl_img_list.append(img)
|
||||
except Exception as e:
|
||||
print_acc(f"Error: {e}")
|
||||
print_acc(f"Error loading control image: {control_path_list[i]}")
|
||||
|
||||
if len(ctrl_img_list) == 0:
|
||||
ctrl_img = None
|
||||
elif not self.sd.has_multiple_control_images:
|
||||
ctrl_img = ctrl_img_list[0]
|
||||
else:
|
||||
ctrl_img = ctrl_img_list
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img)
|
||||
else:
|
||||
prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
|
||||
|
||||
@@ -181,6 +181,10 @@ class BaseModel:
 
         # set true for models that encode control image into text embeddings
         self.encode_control_in_text_embeddings = False
+        # control images will come in as a list for encoding some things if true
+        self.has_multiple_control_images = False
+        # do not resize control images
+        self.use_raw_control_images = False
 
     # properties for old arch for backwards compatibility
     @property

@@ -219,6 +219,10 @@ class StableDiffusion:
 
         # set true for models that encode control image into text embeddings
         self.encode_control_in_text_embeddings = False
+        # control images will come in as a list for encoding some things if true
+        self.has_multiple_control_images = False
+        # do not resize control images
+        self.use_raw_control_images = False
 
     # properties for old arch for backwards compatibility
     @property
@@ -15,7 +15,9 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp
|
||||
import Card from '@/components/Card';
|
||||
import { X } from 'lucide-react';
|
||||
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
||||
import SampleControlImage from '@/components/SampleControlImage';
|
||||
import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';
|
||||
import { handleModelArchChange } from './utils';
|
||||
|
||||
type Props = {
|
||||
jobConfig: JobConfig;
|
||||
@@ -185,58 +187,7 @@ export default function SimpleJob({
|
||||
label="Model Architecture"
|
||||
value={jobConfig.config.process[0].model.arch}
|
||||
onChange={value => {
|
||||
const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch);
|
||||
if (!currentArch || currentArch.name === value) {
|
||||
return;
|
||||
}
|
||||
// update the defaults when a model is selected
|
||||
const newArch = modelArchs.find(model => model.name === value);
|
||||
|
||||
// update vram setting
|
||||
if (!newArch?.additionalSections?.includes('model.low_vram')) {
|
||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||
}
|
||||
|
||||
// revert defaults from previous model
|
||||
for (const key in currentArch.defaults) {
|
||||
setJobConfig(currentArch.defaults[key][1], key);
|
||||
}
|
||||
|
||||
if (newArch?.defaults) {
|
||||
for (const key in newArch.defaults) {
|
||||
setJobConfig(newArch.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.arch');
|
||||
|
||||
// update datasets
|
||||
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
||||
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
||||
const controls = newArch?.controls ?? [];
|
||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||
const newDataset = objectCopy(dataset);
|
||||
newDataset.controls = controls;
|
||||
if (!hasControlPath) {
|
||||
newDataset.control_path = null; // reset control path if not applicable
|
||||
}
|
||||
if (!hasNumFrames) {
|
||||
newDataset.num_frames = 1; // reset num_frames if not applicable
|
||||
}
|
||||
return newDataset;
|
||||
});
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
|
||||
// update samples
|
||||
const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
|
||||
const samples = jobConfig.config.process[0].sample.samples.map(sample => {
|
||||
const newSample = objectCopy(sample);
|
||||
if (!hasSampleCtrlImg) {
|
||||
delete newSample.ctrl_img; // remove ctrl_img if not applicable
|
||||
}
|
||||
return newSample;
|
||||
});
|
||||
setJobConfig(samples, 'config.process[0].sample.samples');
|
||||
handleModelArchChange(jobConfig.config.process[0].model.arch, value, jobConfig, setJobConfig);
|
||||
}}
|
||||
options={groupedModelOptions}
|
||||
/>
|
||||
@@ -557,17 +508,19 @@ export default function SimpleJob({
|
||||
)}
|
||||
|
||||
<FormGroup label="Text Encoder Optimizations" className="pt-2">
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
docKey={'train.unload_text_encoder'}
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{!disableSections.includes('train.unload_text_encoder') && (
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
docKey={'train.unload_text_encoder'}
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<Checkbox
|
||||
label="Cache Text Embeddings"
|
||||
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
|
||||
@@ -642,7 +595,7 @@ export default function SimpleJob({
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Dataset"
|
||||
label="Target Dataset"
|
||||
value={dataset.folder_path}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
options={datasetOptions}
|
||||
@@ -659,6 +612,49 @@ export default function SimpleJob({
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('datasets.multi_control_paths') && (
|
||||
<>
|
||||
<SelectInput
|
||||
label="Control Dataset 1"
|
||||
docKey="datasets.multi_control_paths"
|
||||
value={dataset.control_path_1 ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(
|
||||
value == '' ? null : value,
|
||||
`config.process[0].datasets[${i}].control_path_1`,
|
||||
)
|
||||
}
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Control Dataset 2"
|
||||
docKey="datasets.multi_control_paths"
|
||||
value={dataset.control_path_2 ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(
|
||||
value == '' ? null : value,
|
||||
`config.process[0].datasets[${i}].control_path_2`,
|
||||
)
|
||||
}
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Control Dataset 3"
|
||||
docKey="datasets.multi_control_paths"
|
||||
value={dataset.control_path_3 ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(
|
||||
value == '' ? null : value,
|
||||
`config.process[0].datasets[${i}].control_path_3`,
|
||||
)
|
||||
}
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<NumberInput
|
||||
label="LoRA Weight"
|
||||
value={dataset.network_weight}
|
||||
@@ -1062,30 +1058,43 @@ export default function SimpleJob({
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{modelArch?.additionalSections?.includes('datasets.multi_control_paths') && (
|
||||
<FormGroup label="Control Images" className="pt-2 ml-4">
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-2 mt-2">
|
||||
{['ctrl_img_1', 'ctrl_img_2', 'ctrl_img_3'].map((ctrlKey, ctrl_idx) => (
|
||||
<SampleControlImage
|
||||
key={ctrlKey}
|
||||
instruction={`Add Control Image ${ctrl_idx + 1}`}
|
||||
className=""
|
||||
src={sample[ctrlKey as keyof typeof sample] as string}
|
||||
onNewImageSelected={imagePath => {
|
||||
if (!imagePath) {
|
||||
let newSamples = objectCopy(jobConfig.config.process[0].sample.samples);
|
||||
delete newSamples[i][ctrlKey as keyof typeof sample];
|
||||
setJobConfig(newSamples, 'config.process[0].sample.samples');
|
||||
} else {
|
||||
setJobConfig(imagePath, `config.process[0].sample.samples[${i}].${ctrlKey}`);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</FormGroup>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
|
||||
<div
|
||||
className="h-14 w-14 mt-2 ml-4 border border-gray-500 flex items-center justify-center rounded cursor-pointer hover:bg-gray-700 transition-colors"
|
||||
style={{
|
||||
backgroundImage: sample.ctrl_img
|
||||
? `url(${`/api/img/${encodeURIComponent(sample.ctrl_img)}`})`
|
||||
: 'none',
|
||||
backgroundSize: 'cover',
|
||||
backgroundPosition: 'center',
|
||||
marginBottom: '-1rem',
|
||||
}}
|
||||
onClick={() => {
|
||||
openAddImageModal(imagePath => {
|
||||
console.log('Selected image path:', imagePath);
|
||||
if (!imagePath) return;
|
||||
<SampleControlImage
|
||||
className="mt-6 ml-4"
|
||||
src={sample.ctrl_img}
|
||||
onNewImageSelected={imagePath => {
|
||||
if (!imagePath) {
|
||||
let newSamples = objectCopy(jobConfig.config.process[0].sample.samples);
|
||||
delete newSamples[i].ctrl_img;
|
||||
setJobConfig(newSamples, 'config.process[0].sample.samples');
|
||||
} else {
|
||||
setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`);
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{!sample.ctrl_img && (
|
||||
<div className="text-gray-400 text-xs text-center font-bold">Add Control Image</div>
|
||||
)}
|
||||
</div>
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="pb-4"></div>
|
||||
|
||||
@@ -2,7 +2,6 @@ import { JobConfig, DatasetConfig, SliderConfig } from '@/types';
|
||||
|
||||
export const defaultDatasetConfig: DatasetConfig = {
|
||||
folder_path: '/path/to/images/folder',
|
||||
control_path: null,
|
||||
mask_path: null,
|
||||
mask_min_value: 0.1,
|
||||
default_caption: '',
|
||||
|
||||
@@ -9,12 +9,15 @@ type DisableableSections =
|
||||
| 'network.conv'
|
||||
| 'trigger_word'
|
||||
| 'train.diff_output_preservation'
|
||||
| 'train.unload_text_encoder'
|
||||
| 'slider';
|
||||
|
||||
type AdditionalSections =
|
||||
| 'datasets.control_path'
|
||||
| 'datasets.multi_control_paths'
|
||||
| 'datasets.do_i2v'
|
||||
| 'sample.ctrl_img'
|
||||
| 'sample.multi_ctrl_imgs'
|
||||
| 'datasets.num_frames'
|
||||
| 'model.multistage'
|
||||
| 'model.low_vram';
|
||||
@@ -335,6 +338,28 @@ export const modelArchs: ModelArch[] = [
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'qwen_image_edit_plus',
|
||||
label: 'Qwen-Image-Edit-2509',
|
||||
group: 'instruction',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2509', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].train.unload_text_encoder': [false, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
},
|
||||
disableSections: ['network.conv', 'train.unload_text_encoder'],
|
||||
additionalSections: ['datasets.multi_control_paths', 'sample.multi_ctrl_imgs', 'model.low_vram'],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'hidream',
|
||||
label: 'HiDream',
|
||||
|
||||
105  ui/src/app/jobs/new/utils.ts  (new file)
@@ -0,0 +1,105 @@
|
||||
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||
import { modelArchs, ModelArch } from './options';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
|
||||
export const handleModelArchChange = (
|
||||
currentArchName: string,
|
||||
newArchName: string,
|
||||
jobConfig: JobConfig,
|
||||
setJobConfig: (value: any, key: string) => void,
|
||||
) => {
|
||||
const currentArch = modelArchs.find(a => a.name === currentArchName);
|
||||
if (!currentArch || currentArch.name === newArchName) {
|
||||
return;
|
||||
}
|
||||
|
||||
// update the defaults when a model is selected
|
||||
const newArch = modelArchs.find(model => model.name === newArchName);
|
||||
|
||||
// update vram setting
|
||||
if (!newArch?.additionalSections?.includes('model.low_vram')) {
|
||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||
}
|
||||
|
||||
// revert defaults from previous model
|
||||
for (const key in currentArch.defaults) {
|
||||
setJobConfig(currentArch.defaults[key][1], key);
|
||||
}
|
||||
|
||||
if (newArch?.defaults) {
|
||||
for (const key in newArch.defaults) {
|
||||
setJobConfig(newArch.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(newArchName, 'config.process[0].model.arch');
|
||||
|
||||
// update datasets
|
||||
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
||||
const hasMultiControlPaths = newArch?.additionalSections?.includes('datasets.multi_control_paths') || false;
|
||||
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
||||
const controls = newArch?.controls ?? [];
|
||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||
const newDataset = objectCopy(dataset);
|
||||
newDataset.controls = controls;
|
||||
if (hasMultiControlPaths) {
|
||||
// make sure the config has the multi control paths
|
||||
newDataset.control_path_1 = newDataset.control_path_1 || null;
|
||||
newDataset.control_path_2 = newDataset.control_path_2 || null;
|
||||
newDataset.control_path_3 = newDataset.control_path_3 || null;
|
||||
// if we previously had a single control path and now
|
||||
// we selected a multi control model
|
||||
if (newDataset.control_path && newDataset.control_path !== '') {
|
||||
// only set if not overwriting
|
||||
if (!newDataset.control_path_1) {
|
||||
newDataset.control_path_1 = newDataset.control_path;
|
||||
}
|
||||
}
|
||||
delete newDataset.control_path; // remove single control path
|
||||
} else if (hasControlPath) {
|
||||
newDataset.control_path = newDataset.control_path || null;
|
||||
if (newDataset.control_path_1 && newDataset.control_path_1 !== '') {
|
||||
newDataset.control_path = newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_1) {
|
||||
delete newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_2) {
|
||||
delete newDataset.control_path_2;
|
||||
}
|
||||
if (newDataset.control_path_3) {
|
||||
delete newDataset.control_path_3;
|
||||
}
|
||||
} else {
|
||||
// does not have control images
|
||||
if (newDataset.control_path) {
|
||||
delete newDataset.control_path;
|
||||
}
|
||||
if (newDataset.control_path_1) {
|
||||
delete newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_2) {
|
||||
delete newDataset.control_path_2;
|
||||
}
|
||||
if (newDataset.control_path_3) {
|
||||
delete newDataset.control_path_3;
|
||||
}
|
||||
}
|
||||
if (!hasNumFrames) {
|
||||
newDataset.num_frames = 1; // reset num_frames if not applicable
|
||||
}
|
||||
return newDataset;
|
||||
});
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
|
||||
// update samples
|
||||
const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
|
||||
const samples = jobConfig.config.process[0].sample.samples.map(sample => {
|
||||
const newSample = objectCopy(sample);
|
||||
if (!hasSampleCtrlImg) {
|
||||
delete newSample.ctrl_img; // remove ctrl_img if not applicable
|
||||
}
|
||||
return newSample;
|
||||
});
|
||||
setJobConfig(samples, 'config.process[0].sample.samples');
|
||||
};
|
||||
206  ui/src/components/SampleControlImage.tsx  (new file)
@@ -0,0 +1,206 @@
|
||||
'use client';
|
||||
|
||||
import React, { useCallback, useMemo, useRef, useState } from 'react';
|
||||
import classNames from 'classnames';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { FaUpload, FaImage, FaTimes } from 'react-icons/fa';
|
||||
import { apiClient } from '@/utils/api';
|
||||
import type { AxiosProgressEvent } from 'axios';
|
||||
|
||||
interface Props {
|
||||
src: string | null | undefined;
|
||||
className?: string;
|
||||
instruction?: string;
|
||||
onNewImageSelected: (imagePath: string | null) => void;
|
||||
}
|
||||
|
||||
export default function SampleControlImage({
|
||||
src,
|
||||
className,
|
||||
instruction = 'Add Control Image',
|
||||
onNewImageSelected,
|
||||
}: Props) {
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [uploadProgress, setUploadProgress] = useState(0);
|
||||
const [localPreview, setLocalPreview] = useState<string | null>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement | null>(null);
|
||||
|
||||
const backgroundUrl = useMemo(() => {
|
||||
if (localPreview) return localPreview;
|
||||
if (src) return `/api/img/${encodeURIComponent(src)}`;
|
||||
return null;
|
||||
}, [src, localPreview]);
|
||||
|
||||
const handleUpload = useCallback(
|
||||
async (file: File) => {
|
||||
if (!file) return;
|
||||
setIsUploading(true);
|
||||
setUploadProgress(0);
|
||||
|
||||
const objectUrl = URL.createObjectURL(file);
|
||||
setLocalPreview(objectUrl);
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('files', file);
|
||||
|
||||
try {
|
||||
const resp = await apiClient.post(`/api/img/upload`, formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
onUploadProgress: (evt: AxiosProgressEvent) => {
|
||||
const total = evt.total ?? 100;
|
||||
const loaded = evt.loaded ?? 0;
|
||||
setUploadProgress(Math.round((loaded * 100) / total));
|
||||
},
|
||||
timeout: 0,
|
||||
});
|
||||
|
||||
const uploaded = resp?.data?.files?.[0] ?? null;
|
||||
onNewImageSelected(uploaded);
|
||||
} catch (err) {
|
||||
console.error('Upload failed:', err);
|
||||
setLocalPreview(null);
|
||||
} finally {
|
||||
setIsUploading(false);
|
||||
setUploadProgress(0);
|
||||
URL.revokeObjectURL(objectUrl);
|
||||
if (fileInputRef.current) fileInputRef.current.value = '';
|
||||
}
|
||||
},
|
||||
[onNewImageSelected],
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: File[]) => {
|
||||
if (acceptedFiles.length === 0) return;
|
||||
handleUpload(acceptedFiles[0]);
|
||||
},
|
||||
[handleUpload],
|
||||
);
|
||||
|
||||
const clearImage = useCallback(
|
||||
(e?: React.MouseEvent) => {
|
||||
console.log('clearImage');
|
||||
if (e) {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
}
|
||||
setLocalPreview(null);
|
||||
onNewImageSelected(null);
|
||||
if (fileInputRef.current) fileInputRef.current.value = '';
|
||||
},
|
||||
[onNewImageSelected],
|
||||
);
|
||||
|
||||
// Drag & drop only; click handled via our own hidden input
|
||||
const { getRootProps, isDragActive } = useDropzone({
|
||||
onDrop,
|
||||
accept: { 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'] },
|
||||
multiple: false,
|
||||
noClick: true,
|
||||
noKeyboard: true,
|
||||
});
|
||||
|
||||
const rootProps = getRootProps();
|
||||
|
||||
return (
|
||||
<div
|
||||
{...rootProps}
|
||||
className={classNames(
|
||||
'group relative flex items-center justify-center rounded-xl cursor-pointer ring-1 ring-inset',
|
||||
'transition-all duration-200 select-none overflow-hidden text-center',
|
||||
'h-20 w-20',
|
||||
backgroundUrl ? 'bg-gray-800 ring-gray-700' : 'bg-gradient-to-b from-gray-800 to-gray-900 ring-gray-700',
|
||||
isDragActive ? 'outline outline-2 outline-blue-500' : 'hover:ring-gray-600',
|
||||
className,
|
||||
)}
|
||||
style={
|
||||
backgroundUrl
|
||||
? {
|
||||
backgroundImage: `url("${backgroundUrl}")`,
|
||||
backgroundSize: 'cover',
|
||||
backgroundPosition: 'center',
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onClick={() => !isUploading && fileInputRef.current?.click()}
|
||||
>
|
||||
{/* Hidden input for click-to-open */}
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept="image/*"
|
||||
className="hidden"
|
||||
onChange={e => {
|
||||
const file = e.currentTarget.files?.[0];
|
||||
if (file) handleUpload(file);
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Empty state — centered */}
|
||||
{!backgroundUrl && (
|
||||
<div className="flex flex-col items-center justify-center text-gray-300 text-center">
|
||||
<FaImage className="opacity-80" />
|
||||
<div className="mt-1 text-[10px] font-semibold tracking-wide opacity-80">{instruction}</div>
|
||||
<div className="mt-0.5 text-[9px] opacity-60">Click or drop</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Existing image overlays */}
|
||||
{backgroundUrl && !isUploading && (
|
||||
<>
|
||||
<div
|
||||
className={classNames(
|
||||
'pointer-events-none absolute inset-0 flex items-center justify-center',
|
||||
'bg-black/0 group-hover:bg-black/20',
|
||||
isDragActive && 'bg-black/35',
|
||||
'transition-colors',
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={classNames(
|
||||
'inline-flex items-center gap-1 rounded-md px-2 py-1',
|
||||
'text-[10px] font-semibold',
|
||||
'bg-black/45 text-white/90 backdrop-blur-sm',
|
||||
'opacity-0 group-hover:opacity-100 transition-opacity',
|
||||
)}
|
||||
>
|
||||
<FaUpload className="text-[10px]" />
|
||||
<span>Replace</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Clear (X) button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={clearImage}
|
||||
title="Clear image"
|
||||
aria-label="Clear image"
|
||||
className={classNames(
|
||||
'absolute right-1.5 top-1.5 z-10 inline-flex items-center justify-center',
|
||||
'h-5 w-5 rounded-md bg-black/55 text-white/90',
|
||||
'opacity-0 group-hover:opacity-100 transition-opacity',
|
||||
'hover:bg-black/70',
|
||||
)}
|
||||
>
|
||||
<FaTimes className="text-[10px]" />
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Uploading overlay */}
|
||||
{isUploading && (
|
||||
<div className="absolute inset-0 flex flex-col items-center justify-center bg-black/60 backdrop-blur-[1px] text-center">
|
||||
<div className="w-4/5 max-w-40">
|
||||
<div className="h-1.5 w-full rounded-full bg-white/15">
|
||||
<div
|
||||
className="h-1.5 rounded-full bg-white/80 transition-[width]"
|
||||
style={{ width: `${uploadProgress}%` }}
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-1 text-[10px] font-medium text-white/90">Uploading… {uploadProgress}%</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -53,10 +53,26 @@ const docs: { [key: string]: ConfigDoc } = {
   },
   'datasets.control_path': {
     title: 'Control Dataset',
     description: (
       <>
         The control dataset needs to have files that match the filenames of your training dataset. They should be
         matching file pairs. These images are fed as control/input images during training. The control images will be
         resized to match the training images.
       </>
     ),
   },
+  'datasets.multi_control_paths': {
+    title: 'Multi Control Dataset',
+    description: (
+      <>
+        The control datasets need to have files that match the filenames of your training dataset. They should be
+        matching file pairs. These images are fed as control/input images during training.
+        <br />
+        <br />
+        For multi control datasets, the controls are applied in the order they are listed. If the model does not
+        require the images to share the same aspect ratio, as with Qwen/Qwen-Image-Edit-2509, the control images
+        do not need to match the size or aspect ratio of the target image; they are automatically resized to the
+        ideal resolutions for the model / target images.
+      </>
+    ),
+  },
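The "automatically resized to the ideal resolutions" behaviour boils down to scaling each control image to a fixed pixel area while keeping its aspect ratio, then snapping both sides to multiples of 32 (one target area for the text-encoder conditioning path, another for the VAE path). A rough standalone sketch of that math; the exact area constants are an assumption here and live in the diffusers Qwen-Image-Edit-Plus pipeline as CONDITION_IMAGE_SIZE and VAE_IMAGE_SIZE.

# Rough sketch of the area-preserving, 32-aligned resize applied to control images.
import math

def fit_to_area(w: int, h: int, target_area: int = 384 * 384) -> tuple[int, int]:
    # target_area stands in for CONDITION_IMAGE_SIZE / VAE_IMAGE_SIZE.
    ratio = w / h
    new_w = math.sqrt(target_area * ratio)
    new_h = new_w / ratio
    return round(new_w / 32) * 32, round(new_h / 32) * 32

print(fit_to_area(1920, 1080))  # -> (512, 288): same aspect, ~the target area, 32-aligned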
@@ -83,12 +83,15 @@ export interface DatasetConfig {
   cache_latents_to_disk?: boolean;
   resolution: number[];
   controls: string[];
-  control_path: string | null;
+  control_path?: string | null;
   num_frames: number;
   shrink_video_to_frames: boolean;
   do_i2v: boolean;
   flip_x: boolean;
   flip_y: boolean;
+  control_path_1?: string | null;
+  control_path_2?: string | null;
+  control_path_3?: string | null;
 }
 
 export interface EMAConfig {
@@ -155,6 +158,9 @@ export interface SampleItem {
   ctrl_img?: string | null;
   ctrl_idx?: number;
   network_multiplier?: number;
+  ctrl_img_1?: string | null;
+  ctrl_img_2?: string | null;
+  ctrl_img_3?: string | null;
 }
 
 export interface SampleConfig {
@@ -1 +1 @@
-VERSION = "0.5.10"
+VERSION = "0.6.0"