Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added support for new concept slider training script to CLI and UI
extensions_built_in/concept_slider/ConceptSliderTrainer.py (new file, 277 lines)
@@ -0,0 +1,277 @@
from collections import OrderedDict
from typing import Optional

import torch

from extensions_built_in.sd_trainer.DiffusionTrainer import DiffusionTrainer
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.train_tools import get_torch_dtype


class ConceptSliderTrainerConfig:
    def __init__(self, **kwargs):
        self.guidance_strength: float = kwargs.get("guidance_strength", 3.0)
        self.anchor_strength: float = kwargs.get("anchor_strength", 1.0)
        self.positive_prompt: str = kwargs.get("positive_prompt", "")
        self.negative_prompt: str = kwargs.get("negative_prompt", "")
        self.target_class: str = kwargs.get("target_class", "")
        self.anchor_class: Optional[str] = kwargs.get("anchor_class", None)


class ConceptSliderTrainer(DiffusionTrainer):
    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        self.do_guided_loss = True

        self.slider: ConceptSliderTrainerConfig = ConceptSliderTrainerConfig(
            **self.config.get("slider", {})
        )

        self.positive_prompt = self.slider.positive_prompt
        self.positive_prompt_embeds: Optional[PromptEmbeds] = None
        self.negative_prompt = self.slider.negative_prompt
        self.negative_prompt_embeds: Optional[PromptEmbeds] = None
        self.target_class = self.slider.target_class
        self.target_class_embeds: Optional[PromptEmbeds] = None
        self.anchor_class = self.slider.anchor_class
        self.anchor_class_embeds: Optional[PromptEmbeds] = None

    def hook_before_train_loop(self):
        # do this before calling parent as it unloads the text encoder if requested
        if self.is_caching_text_embeddings:
            # make sure model is on cpu for this part so we don't oom.
            self.sd.unet.to("cpu")

        # cache the slider prompt embeds (positive, target class, negative, optional anchor)
        with torch.no_grad():
            self.positive_prompt_embeds = (
                self.sd.encode_prompt(
                    [self.positive_prompt],
                )
                .to(self.device_torch, dtype=self.sd.torch_dtype)
                .detach()
            )

            self.target_class_embeds = (
                self.sd.encode_prompt(
                    [self.target_class],
                )
                .to(self.device_torch, dtype=self.sd.torch_dtype)
                .detach()
            )

            self.negative_prompt_embeds = (
                self.sd.encode_prompt(
                    [self.negative_prompt],
                )
                .to(self.device_torch, dtype=self.sd.torch_dtype)
                .detach()
            )

            if self.anchor_class is not None:
                self.anchor_class_embeds = (
                    self.sd.encode_prompt(
                        [self.anchor_class],
                    )
                    .to(self.device_torch, dtype=self.sd.torch_dtype)
                    .detach()
                )

        # call parent
        super().hook_before_train_loop()

    def get_guided_loss(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: "DataLoaderBatchDTO",
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        **kwargs,
    ):
        # todo: for embeddings, we need to run without trigger words
        was_unet_training = self.sd.unet.training
        was_network_active = False
        if self.network is not None:
            was_network_active = self.network.is_active
            self.network.is_active = False

        # do our prior preds first
        with torch.no_grad():
            dtype = get_torch_dtype(self.train_config.dtype)
            self.sd.unet.eval()
            noisy_latents = noisy_latents.to(self.device_torch, dtype=dtype).detach()

            batch_size = noisy_latents.shape[0]

            positive_embeds = concat_prompt_embeds(
                [self.positive_prompt_embeds] * batch_size
            ).to(self.device_torch, dtype=dtype)
            target_class_embeds = concat_prompt_embeds(
                [self.target_class_embeds] * batch_size
            ).to(self.device_torch, dtype=dtype)
            negative_embeds = concat_prompt_embeds(
                [self.negative_prompt_embeds] * batch_size
            ).to(self.device_torch, dtype=dtype)

            if self.anchor_class_embeds is not None:
                anchor_embeds = concat_prompt_embeds(
                    [self.anchor_class_embeds] * batch_size
                ).to(self.device_torch, dtype=dtype)

            if self.anchor_class_embeds is not None:
                # if we have an anchor, include it
                combo_embeds = concat_prompt_embeds(
                    [
                        positive_embeds,
                        target_class_embeds,
                        negative_embeds,
                        anchor_embeds,
                    ]
                )
                num_embeds = 4
            else:
                combo_embeds = concat_prompt_embeds(
                    [positive_embeds, target_class_embeds, negative_embeds]
                )
                num_embeds = 3

            # do them in one batch; VRAM should handle it since we are in no_grad
            combo_pred = self.sd.predict_noise(
                latents=torch.cat([noisy_latents] * num_embeds, dim=0),
                conditional_embeddings=combo_embeds,
                timestep=torch.cat([timesteps] * num_embeds, dim=0),
                guidance_scale=1.0,
                guidance_embedding_scale=1.0,
                batch=batch,
            )

            if self.anchor_class_embeds is not None:
                positive_pred, neutral_pred, negative_pred, anchor_target = (
                    combo_pred.chunk(4, dim=0)
                )
            else:
                anchor_target = None
                positive_pred, neutral_pred, negative_pred = combo_pred.chunk(3, dim=0)

            # calculate the targets
            guidance_scale = self.slider.guidance_strength

            # enhance_positive_target = neutral_pred + guidance_scale * (
            #     positive_pred - negative_pred
            # )
            # enhance_negative_target = neutral_pred + guidance_scale * (
            #     negative_pred - positive_pred
            # )
            # erase_negative_target = neutral_pred - guidance_scale * (
            #     negative_pred - positive_pred
            # )
            # erase_positive_target = neutral_pred - guidance_scale * (
            #     positive_pred - negative_pred
            # )

            positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred)
            negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred)

            enhance_positive_target = neutral_pred + guidance_scale * positive
            enhance_negative_target = neutral_pred + guidance_scale * negative
            erase_negative_target = neutral_pred - guidance_scale * negative
            erase_positive_target = neutral_pred - guidance_scale * positive

        if was_unet_training:
            self.sd.unet.train()

        # restore network
        if self.network is not None:
            self.network.is_active = was_network_active

        if self.anchor_class_embeds is not None:
            # do a grad inference with our target prompt
            embeds = concat_prompt_embeds([target_class_embeds, anchor_embeds]).to(
                self.device_torch, dtype=dtype
            )

            noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0).to(
                self.device_torch, dtype=dtype
            )
            timesteps = torch.cat([timesteps, timesteps], dim=0)
        else:
            embeds = target_class_embeds.to(self.device_torch, dtype=dtype)

        # do positive first
        self.network.set_multiplier(1.0)
        pred = self.sd.predict_noise(
            latents=noisy_latents,
            conditional_embeddings=embeds,
            timestep=timesteps,
            guidance_scale=1.0,
            guidance_embedding_scale=1.0,
            batch=batch,
        )

        if self.anchor_class_embeds is not None:
            class_pred, anchor_pred = pred.chunk(2, dim=0)
        else:
            class_pred = pred
            anchor_pred = None

        # enhance positive loss
        enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_positive_target)

        erase_loss = torch.nn.functional.mse_loss(class_pred, erase_negative_target)

        if anchor_target is None:
            anchor_loss = torch.zeros_like(erase_loss)
        else:
            anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target)

        anchor_loss = anchor_loss * self.slider.anchor_strength

        # send backward now because gradient checkpointing needs network polarity intact
        total_pos_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0
        total_pos_loss.backward()
        total_pos_loss = total_pos_loss.detach()

        # now do negative
        self.network.set_multiplier(-1.0)
        pred = self.sd.predict_noise(
            latents=noisy_latents,
            conditional_embeddings=embeds,
            timestep=timesteps,
            guidance_scale=1.0,
            guidance_embedding_scale=1.0,
            batch=batch,
        )

        if self.anchor_class_embeds is not None:
            class_pred, anchor_pred = pred.chunk(2, dim=0)
        else:
            class_pred = pred
            anchor_pred = None

        # enhance negative loss
        enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_negative_target)
        erase_loss = torch.nn.functional.mse_loss(class_pred, erase_positive_target)

        if anchor_target is None:
            anchor_loss = torch.zeros_like(erase_loss)
        else:
            anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target)
        anchor_loss = anchor_loss * self.slider.anchor_strength
        total_neg_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0
        total_neg_loss.backward()
        total_neg_loss = total_neg_loss.detach()

        self.network.set_multiplier(1.0)

        total_loss = (total_pos_loss + total_neg_loss) / 2.0

        # add a grad so backward works right
        total_loss.requires_grad_(True)
        return total_loss
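
For reference, the trainer above pulls its settings from a "slider" block on the process config (self.config.get("slider", {})). Below is a minimal sketch of that block as a plain Python dict: the key names are the ones ConceptSliderTrainerConfig reads, and the example values mirror defaultSliderConfig from the UI changes later in this commit. The exact config file layout the CLI expects around this block is not shown in the diff, so treat the surrounding structure as an assumption.

# Sketch only: the kwargs ConceptSliderTrainerConfig accepts, using the UI defaults.
slider_block = {
    "guidance_strength": 3.0,   # scales the positive/negative direction used for the targets
    "anchor_strength": 1.0,     # weight applied to the anchor-preservation MSE term
    "positive_prompt": "person who is happy",
    "negative_prompt": "person who is sad",
    "target_class": "person",   # class prompt the LoRA is trained against
    "anchor_class": None,       # optional; when set, an anchor loss is added
}
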
extensions_built_in/concept_slider/__init__.py (new file, 26 lines)
@@ -0,0 +1,26 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# This is for generic training (LoRA, Dreambooth, FineTuning)
class ConceptSliderTrainerTrainer(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "concept_slider"

    # name is the name of the extension for printing
    name = "Concept Slider Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ConceptSliderTrainer import ConceptSliderTrainer

        return ConceptSliderTrainer


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ConceptSliderTrainerTrainer
]
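
The uid declared here is what a job config's process "type" resolves to (the UI sets type: 'concept_slider' for this trainer). The snippet below is only an illustrative sketch of that lookup using the names defined in this file; the toolkit's real extension loader is not part of this diff, so resolve_trainer is a hypothetical helper, not an actual API.

# Hypothetical helper (assumption): match a process "type" against Extension.uid.
from extensions_built_in.concept_slider import AI_TOOLKIT_EXTENSIONS


def resolve_trainer(process_type: str):
    for ext in AI_TOOLKIT_EXTENSIONS:
        if ext.uid == process_type:
            return ext.get_process()  # lazily imports ConceptSliderTrainer
    raise ValueError(f"no extension registered for type '{process_type}'")


trainer_cls = resolve_trainer("concept_slider")
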
extensions_built_in/sd_trainer/DiffusionTrainer.py (new file, 297 lines)
@@ -0,0 +1,297 @@
from collections import OrderedDict
import os
import sqlite3
import asyncio
import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional
import threading
import time
import signal

AITK_Status = Literal["running", "stopped", "error", "completed"]


class DiffusionTrainer(SDTrainer):
    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super(DiffusionTrainer, self).__init__(process_id, job, config, **kwargs)
        self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
        self.job_id = os.environ.get("AITK_JOB_ID", None)
        self.job_id = self.job_id.strip() if self.job_id is not None else None
        self.is_ui_trainer = True
        if not os.path.exists(self.sqlite_db_path):
            self.is_ui_trainer = False
        else:
            print(f"Using SQLite database at {self.sqlite_db_path}")
        if self.job_id is None:
            self.is_ui_trainer = False
        else:
            print(f"Job ID: \"{self.job_id}\"")

        if self.is_ui_trainer:
            self.is_stopping = False
            # Create a thread pool for database operations
            self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            # Track all async tasks
            self._async_tasks = []
            # Initialize the status
            self._run_async_operation(self._update_status("running", "Starting"))
            self._stop_watcher_started = False
            # self.start_stop_watcher(interval_sec=2.0)

    def start_stop_watcher(self, interval_sec: float = 5.0):
        """
        Start a daemon thread that periodically checks should_stop()
        and terminates the process immediately when triggered.
        """
        if not self.is_ui_trainer:
            return
        if getattr(self, "_stop_watcher_started", False):
            return
        self._stop_watcher_started = True
        t = threading.Thread(
            target=self._stop_watcher_thread, args=(interval_sec,), daemon=True
        )
        t.start()

    def _stop_watcher_thread(self, interval_sec: float):
        while True:
            try:
                if self.should_stop():
                    # Mark and update status (non-blocking; uses existing infra)
                    self.is_stopping = True
                    self._run_async_operation(
                        self._update_status("stopped", "Job stopped (remote)")
                    )
                    # Best-effort flush of pending async ops
                    try:
                        asyncio.run(self.wait_for_all_async())
                    except RuntimeError:
                        pass
                    # Try to stop the DB thread pool quickly
                    try:
                        self.thread_pool.shutdown(wait=False, cancel_futures=True)
                    except TypeError:
                        self.thread_pool.shutdown(wait=False)
                    print("")
                    print("****************************************************")
                    print("    Stop signal received; terminating process.      ")
                    print("****************************************************")
                    os.kill(os.getpid(), signal.SIGINT)
                time.sleep(interval_sec)
            except Exception:
                time.sleep(interval_sec)

    def _run_async_operation(self, coro):
        """Helper method to run an async coroutine and track the task."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop exists, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Create a task and track it
        if loop.is_running():
            task = asyncio.run_coroutine_threadsafe(coro, loop)
            self._async_tasks.append(asyncio.wrap_future(task))
        else:
            task = loop.create_task(coro)
            self._async_tasks.append(task)
            loop.run_until_complete(task)

    async def _execute_db_operation(self, operation_func):
        """Execute a database operation in a separate thread to avoid blocking."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.thread_pool, operation_func)

    def _db_connect(self):
        """Create a new connection for each operation to avoid locking."""
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn

    def should_stop(self):
        if not self.is_ui_trainer:
            return False
        def _check_stop():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT stop FROM Job WHERE id = ?", (self.job_id,))
                stop = cursor.fetchone()
                return False if stop is None else stop[0] == 1

        return _check_stop()

    def maybe_stop(self):
        if not self.is_ui_trainer:
            return
        if self.should_stop():
            self._run_async_operation(
                self._update_status("stopped", "Job stopped"))
            self.is_stopping = True
            raise Exception("Job stopped")

    async def _update_key(self, key, value):
        if not self.accelerator.is_main_process:
            return

        def _do_update():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN IMMEDIATE")
                try:
                    # Convert the value to a string if it isn't one already
                    if isinstance(value, str):
                        value_to_insert = value
                    else:
                        value_to_insert = str(value)

                    # The value is bound as a parameter; the column name is interpolated directly
                    update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
                    cursor.execute(
                        update_query, (value_to_insert, self.job_id))
                finally:
                    cursor.execute("COMMIT")

        await self._execute_db_operation(_do_update)

    def update_step(self):
        """Non-blocking update of the step count."""
        if self.accelerator.is_main_process and self.is_ui_trainer:
            self._run_async_operation(self._update_key("step", self.step_num))

    def update_db_key(self, key, value):
        """Non-blocking update of a key in the database."""
        if self.accelerator.is_main_process and self.is_ui_trainer:
            self._run_async_operation(self._update_key(key, value))

    async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
        if not self.accelerator.is_main_process or not self.is_ui_trainer:
            return

        def _do_update():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN IMMEDIATE")
                try:
                    if info is not None:
                        cursor.execute(
                            "UPDATE Job SET status = ?, info = ? WHERE id = ?",
                            (status, info, self.job_id)
                        )
                    else:
                        cursor.execute(
                            "UPDATE Job SET status = ? WHERE id = ?",
                            (status, self.job_id)
                        )
                finally:
                    cursor.execute("COMMIT")

        await self._execute_db_operation(_do_update)

    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Non-blocking update of status."""
        if self.accelerator.is_main_process and self.is_ui_trainer:
            self._run_async_operation(self._update_status(status, info))

    async def wait_for_all_async(self):
        """Wait for all tracked async operations to complete."""
        if not self._async_tasks:
            return

        try:
            await asyncio.gather(*self._async_tasks)
        except Exception:
            pass
        finally:
            # Clear the task list after completion
            self._async_tasks.clear()

    def on_error(self, e: Exception):
        super(DiffusionTrainer, self).on_error(e)
        if self.is_ui_trainer:
            if self.accelerator.is_main_process and not self.is_stopping:
                self.update_status("error", str(e))
                self.update_db_key("step", self.last_save_step)
            asyncio.run(self.wait_for_all_async())
            self.thread_pool.shutdown(wait=True)

    def handle_timing_print_hook(self, timing_dict):
        if "train_loop" not in timing_dict:
            print("train_loop not found in timing_dict", timing_dict)
            return
        seconds_per_iter = timing_dict["train_loop"]
        # determine iter/sec or sec/iter
        if seconds_per_iter < 1:
            iters_per_sec = 1 / seconds_per_iter
            self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
        else:
            self.update_db_key(
                "speed_string", f"{seconds_per_iter:.2f} sec/iter")

    def done_hook(self):
        super(DiffusionTrainer, self).done_hook()
        if self.is_ui_trainer:
            self.update_status("completed", "Training completed")
            # Wait for all async operations to finish before shutting down
            asyncio.run(self.wait_for_all_async())
            self.thread_pool.shutdown(wait=True)

    def end_step_hook(self):
        super(DiffusionTrainer, self).end_step_hook()
        if self.is_ui_trainer:
            self.update_step()
            self.maybe_stop()

    def hook_before_model_load(self):
        super().hook_before_model_load()
        if self.is_ui_trainer:
            self.maybe_stop()
            self.update_status("running", "Loading model")

    def before_dataset_load(self):
        super().before_dataset_load()
        if self.is_ui_trainer:
            self.maybe_stop()
            self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        if self.is_ui_trainer:
            self.maybe_stop()
            self.update_step()
            self.update_status("running", "Training")
            self.timer.add_after_print_hook(self.handle_timing_print_hook)

    def status_update_hook_func(self, string):
        self.update_status("running", string)

    def hook_after_sd_init_before_load(self):
        super().hook_after_sd_init_before_load()
        if self.is_ui_trainer:
            self.maybe_stop()
            self.sd.add_status_update_hook(self.status_update_hook_func)

    def sample_step_hook(self, img_num, total_imgs):
        super().sample_step_hook(img_num, total_imgs)
        if self.is_ui_trainer:
            self.maybe_stop()
            self.update_status(
                "running", f"Generating images - {img_num + 1}/{total_imgs}")

    def sample(self, step=None, is_first=False):
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 0/{total_imgs}")
        super().sample(step, is_first)
        self.maybe_stop()
        self.update_status("running", "Training")

    def save(self, step=None):
        self.maybe_stop()
        self.update_status("running", "Saving model")
        super().save(step)
        self.maybe_stop()
        self.update_status("running", "Training")
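
DiffusionTrainer assumes a SQLite database (default ./aitk_db.db) with a Job table that the UI owns: the trainer polls the stop column and writes status, info, step, and speed_string for the row whose id matches the AITK_JOB_ID environment variable. The sketch below creates a table with just those columns so the polling and update queries above can be exercised standalone; the column names come from the SQL in this file, but the types and the example job id are assumptions (the UI's real schema may differ).

# Sketch only: a minimal Job table matching the queries used by DiffusionTrainer.
# Column types and the example id are assumptions; only the column names come from the code above.
import os
import sqlite3

conn = sqlite3.connect("./aitk_db.db")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS Job (
        id TEXT PRIMARY KEY,
        status TEXT,
        info TEXT,
        step INTEGER,
        speed_string TEXT,
        stop INTEGER DEFAULT 0  -- the UI sets this to 1 to request a stop
    )
    """
)
conn.execute("INSERT OR IGNORE INTO Job (id) VALUES (?)", ("example-job-id",))
conn.commit()
conn.close()

os.environ["AITK_JOB_ID"] = "example-job-id"  # the trainer reads the job id from this env var
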
@@ -16,8 +16,10 @@ class SDTrainerExtension(Extension):
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SDTrainer import SDTrainer

        return SDTrainer


# This is for generic training (LoRA, Dreambooth, FineTuning)
class UITrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
@@ -32,9 +34,28 @@ class UITrainerExtension(Extension):
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .UITrainer import UITrainer

        return UITrainer


# This is a universal trainer that can be run from the UI or the API
class DiffusionTrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "diffusion_trainer"

    # name is the name of the extension for printing
    name = "Diffusion Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .DiffusionTrainer import DiffusionTrainer

        return DiffusionTrainer


# for backwards compatibility
class TextualInversionTrainer(SDTrainerExtension):
    uid = "textual_inversion_trainer"
@@ -42,5 +63,8 @@ class TextualInversionTrainer(SDTrainerExtension):

AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    SDTrainerExtension, TextualInversionTrainer, UITrainerExtension
    SDTrainerExtension,
    TextualInversionTrainer,
    UITrainerExtension,
    DiffusionTrainerExtension,
]
@@ -57,6 +57,13 @@ class SampleItem:
        self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
        self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
        self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
        # convert to a number if it is a string
        if isinstance(self.network_multiplier, str):
            try:
                self.network_multiplier = float(self.network_multiplier)
            except ValueError:
                print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0")
                self.network_multiplier = 1.0


class SampleConfig:
@@ -108,7 +108,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props

    // We have to ensure certain things are always set
    try {
      parsed.config.process[0].type = 'ui_trainer';
      // parsed.config.process[0].type = 'ui_trainer';
      parsed.config.process[0].sqlite_db_path = './aitk_db.db';
      parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
      parsed.config.process[0].device = 'cuda';
@@ -1,6 +1,13 @@
'use client';
import { useMemo } from 'react';
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
import {
  modelArchs,
  ModelArch,
  groupedModelOptions,
  quantizationOptions,
  defaultQtype,
  jobTypeOptions,
} from './options';
import { defaultDatasetConfig } from './jobConfig';
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
import { objectCopy } from '@/utils/basic';
@@ -8,7 +15,7 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp
import Card from '@/components/Card';
import { X } from 'lucide-react';
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
import {FlipHorizontal2, FlipVertical2} from "lucide-react"
import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';

type Props = {
  jobConfig: JobConfig;
@@ -39,6 +46,21 @@ export default function SimpleJob({
    return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
  }, [jobConfig.config.process[0].model.arch]);

  const jobType = useMemo(() => {
    return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type);
  }, [jobConfig.config.process[0].type]);

  const disableSections = useMemo(() => {
    let sections: string[] = [];
    if (modelArch?.disableSections) {
      sections = sections.concat(modelArch.disableSections);
    }
    if (jobType?.disableSections) {
      sections = sections.concat(jobType.disableSections);
    }
    return sections;
  }, [modelArch, jobType]);

  const isVideoModel = !!(modelArch?.group === 'video');

  const numTopCards = useMemo(() => {
@@ -46,12 +68,14 @@ export default function SimpleJob({
    if (modelArch?.additionalSections?.includes('model.multistage')) {
      count += 1; // add multistage card
    }
    if (!modelArch?.disableSections?.includes('model.quantize')) {
    if (!disableSections.includes('model.quantize')) {
      count += 1; // add quantization card
    }
    if (!disableSections.includes('slider')) {
      count += 1; // add slider card
    }
    return count;

  }, [modelArch]);
  }, [modelArch, disableSections]);

  let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
@@ -62,6 +86,20 @@ export default function SimpleJob({
    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
  }

  const numTrainingCols = useMemo(() => {
    let count = 4;
    if (!disableSections.includes('train.diff_output_preservation')) {
      count += 1;
    }
    return count;
  }, [disableSections]);

  let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6';

  if (numTrainingCols == 5) {
    trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6';
  }

  const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
    const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
    if (!hasARA) {
@@ -78,7 +116,7 @@ export default function SimpleJob({
    let ARAs: SelectOption[] = [];
    if (modelArch.accuracyRecoveryAdapters) {
      for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
        ARAs.push({ value, label });
      }
    }
    if (ARAs.length > 0) {
@@ -124,19 +162,21 @@ export default function SimpleJob({
          onChange={value => setGpuIDs(value)}
          options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
        />
        <TextInput
          label="Trigger Word"
          value={jobConfig.config.process[0].trigger_word || ''}
          docKey="config.process[0].trigger_word"
          onChange={(value: string | null) => {
            if (value?.trim() === '') {
              value = null;
            }
            setJobConfig(value, 'config.process[0].trigger_word');
          }}
          placeholder=""
          required
        />
        {disableSections.includes('trigger_word') ? null : (
          <TextInput
            label="Trigger Word"
            value={jobConfig.config.process[0].trigger_word || ''}
            docKey="config.process[0].trigger_word"
            onChange={(value: string | null) => {
              if (value?.trim() === '') {
                value = null;
              }
              setJobConfig(value, 'config.process[0].trigger_word');
            }}
            placeholder=""
            required
          />
        )}
      </Card>

      {/* Model Configuration Section */}
@@ -223,7 +263,7 @@ export default function SimpleJob({
            </FormGroup>
          )}
        </Card>
        {modelArch?.disableSections?.includes('model.quantize') ? null : (
        {disableSections.includes('model.quantize') ? null : (
          <Card title="Quantization">
            <SelectInput
              label="Transformer"
@@ -270,14 +310,14 @@ export default function SimpleJob({
              />
            </FormGroup>
            <NumberInput
              label="Switch Every"
              value={jobConfig.config.process[0].train.switch_boundary_every}
              onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
              placeholder="eg. 1"
              docKey={'train.switch_boundary_every'}
              min={1}
              required
            />
          </Card>
        )}
        <Card title="Target">
@@ -319,7 +359,7 @@ export default function SimpleJob({
            max={1024}
            required
          />
          {modelArch?.disableSections?.includes('network.conv') ? null : (
          {disableSections.includes('network.conv') ? null : (
            <NumberInput
              label="Conv Rank"
              value={jobConfig.config.process[0].network.conv}
@@ -336,6 +376,38 @@ export default function SimpleJob({
            </>
          )}
        </Card>
        {!disableSections.includes('slider') && (
          <Card title="Slider">
            <TextInput
              label="Target Class"
              className=""
              value={jobConfig.config.process[0].slider?.target_class ?? ''}
              onChange={value => setJobConfig(value, 'config.process[0].slider.target_class')}
              placeholder="eg. person"
            />
            <TextInput
              label="Positive Prompt"
              className=""
              value={jobConfig.config.process[0].slider?.positive_prompt ?? ''}
              onChange={value => setJobConfig(value, 'config.process[0].slider.positive_prompt')}
              placeholder="eg. person who is happy"
            />
            <TextInput
              label="Negative Prompt"
              className=""
              value={jobConfig.config.process[0].slider?.negative_prompt ?? ''}
              onChange={value => setJobConfig(value, 'config.process[0].slider.negative_prompt')}
              placeholder="eg. person who is sad"
            />
            <TextInput
              label="Anchor Class"
              className=""
              value={jobConfig.config.process[0].slider?.anchor_class ?? ''}
              onChange={value => setJobConfig(value, 'config.process[0].slider.anchor_class')}
              placeholder=""
            />
          </Card>
        )}
        <Card title="Save">
          <SelectInput
            label="Data Type"
@@ -367,7 +439,7 @@ export default function SimpleJob({
        </div>
        <div>
          <Card title="Training">
            <div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
            <div className={trainingBarClass}>
              <div>
                <NumberInput
                  label="Batch Size"
@@ -426,11 +498,11 @@ export default function SimpleJob({
                />
              </div>
              <div>
                {modelArch?.disableSections?.includes('train.timestep_type') ? null : (
                {disableSections.includes('train.timestep_type') ? null : (
                  <SelectInput
                    label="Timestep Type"
                    value={jobConfig.config.process[0].train.timestep_type}
                    disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
                    disabled={disableSections.includes('train.timestep_type') || false}
                    onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
                    options={[
                      { value: 'sigmoid', label: 'Sigmoid' },
@@ -508,33 +580,39 @@ export default function SimpleJob({
                  </FormGroup>
                </div>
                <div>
                  <FormGroup label="Regularization">
                    <Checkbox
                      label="Differtial Output Preservation"
                      className="pt-1"
                      checked={jobConfig.config.process[0].train.diff_output_preservation || false}
                      onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
                    />
                  </FormGroup>
                  {jobConfig.config.process[0].train.diff_output_preservation && (
                  {disableSections.includes('train.diff_output_preservation') ? null : (
                    <>
                      <NumberInput
                        label="DOP Loss Multiplier"
                        className="pt-2"
                        value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
                        onChange={value =>
                          setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
                        }
                        placeholder="eg. 1.0"
                        min={0}
                      />
                      <TextInput
                        label="DOP Preservation Class"
                        className="pt-2"
                        value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
                        onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
                        placeholder="eg. woman"
                      />
                      <FormGroup label="Regularization">
                        <Checkbox
                          label="Differential Output Preservation"
                          className="pt-1"
                          checked={jobConfig.config.process[0].train.diff_output_preservation || false}
                          onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
                        />
                      </FormGroup>
                      {jobConfig.config.process[0].train.diff_output_preservation && (
                        <>
                          <NumberInput
                            label="DOP Loss Multiplier"
                            className="pt-2"
                            value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
                            onChange={value =>
                              setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
                            }
                            placeholder="eg. 1.0"
                            min={0}
                          />
                          <TextInput
                            label="DOP Preservation Class"
                            className="pt-2"
                            value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
                            onChange={value =>
                              setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
                            }
                            placeholder="eg. woman"
                          />
                        </>
                      )}
                    </>
                  )}
                </div>
@@ -641,12 +719,20 @@ export default function SimpleJob({
                  </FormGroup>
                  <FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
                    <Checkbox
                      label={<>Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" /></>}
                      label={
                        <>
                          Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" />
                        </>
                      }
                      checked={dataset.flip_x || false}
                      onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
                    />
                    <Checkbox
                      label={<>Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" /></>}
                      label={
                        <>
                          Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" />
                        </>
                      }
                      checked={dataset.flip_y || false}
                      onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
                    />
@@ -812,7 +898,7 @@ export default function SimpleJob({
                onChange={value => {
                  setJobConfig(value, 'config.process[0].train.skip_first_sample');
                  // cannot do both, so disable the other
                  if (value){
                  if (value) {
                    setJobConfig(false, 'config.process[0].train.force_first_sample');
                  }
                }}
@@ -827,7 +913,7 @@ export default function SimpleJob({
                onChange={value => {
                  setJobConfig(value, 'config.process[0].train.force_first_sample');
                  // cannot do both, so disable the other
                  if (value){
                  if (value) {
                    setJobConfig(false, 'config.process[0].train.skip_first_sample');
                  }
                }}
@@ -841,7 +927,7 @@ export default function SimpleJob({
                onChange={value => {
                  setJobConfig(value, 'config.process[0].train.disable_sampling');
                  // cannot do both, so disable the other
                  if (value){
                  if (value) {
                    setJobConfig(false, 'config.process[0].train.force_first_sample');
                  }
                }}
@@ -866,6 +952,113 @@ export default function SimpleJob({
                  placeholder="Enter prompt"
                  required
                />
                <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mt-2">
                  <TextInput
                    label={`Width`}
                    value={sample.width ? `${sample.width}` : ''}
                    onChange={value => {
                      // remove any non-numeric characters
                      value = value.replace(/\D/g, '');
                      if (value === '') {
                        // remove the key from the config if empty
                        let newConfig = objectCopy(jobConfig);
                        if (newConfig.config.process[0].sample.samples[i]) {
                          delete newConfig.config.process[0].sample.samples[i].width;
                          setJobConfig(
                            newConfig.config.process[0].sample.samples,
                            'config.process[0].sample.samples',
                          );
                        }
                      } else {
                        const intValue = parseInt(value);
                        if (!isNaN(intValue)) {
                          setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`);
                        } else {
                          console.warn('Invalid width value:', value);
                        }
                      }
                    }}
                    placeholder={`${jobConfig.config.process[0].sample.width} (default)`}
                  />
                  <TextInput
                    label={`Height`}
                    value={sample.height ? `${sample.height}` : ''}
                    onChange={value => {
                      // remove any non-numeric characters
                      value = value.replace(/\D/g, '');
                      if (value === '') {
                        // remove the key from the config if empty
                        let newConfig = objectCopy(jobConfig);
                        if (newConfig.config.process[0].sample.samples[i]) {
                          delete newConfig.config.process[0].sample.samples[i].height;
                          setJobConfig(
                            newConfig.config.process[0].sample.samples,
                            'config.process[0].sample.samples',
                          );
                        }
                      } else {
                        const intValue = parseInt(value);
                        if (!isNaN(intValue)) {
                          setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`);
                        } else {
                          console.warn('Invalid height value:', value);
                        }
                      }
                    }}
                    placeholder={`${jobConfig.config.process[0].sample.height} (default)`}
                  />
                  <TextInput
                    label={`Seed`}
                    value={sample.seed ? `${sample.seed}` : ''}
                    onChange={value => {
                      // remove any non-numeric characters
                      value = value.replace(/\D/g, '');
                      if (value === '') {
                        // remove the key from the config if empty
                        let newConfig = objectCopy(jobConfig);
                        if (newConfig.config.process[0].sample.samples[i]) {
                          delete newConfig.config.process[0].sample.samples[i].seed;
                          setJobConfig(
                            newConfig.config.process[0].sample.samples,
                            'config.process[0].sample.samples',
                          );
                        }
                      } else {
                        const intValue = parseInt(value);
                        if (!isNaN(intValue)) {
                          setJobConfig(intValue, `config.process[0].sample.samples[${i}].seed`);
                        } else {
                          console.warn('Invalid seed value:', value);
                        }
                      }
                    }}
                    placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`}
                  />
                  <TextInput
                    label={`LoRA Scale`}
                    value={sample.network_multiplier ? `${sample.network_multiplier}` : ''}
                    onChange={value => {
                      // strip any characters other than digits, '.' and '-'
                      value = value.replace(/[^0-9.-]/g, '');
                      if (value === '') {
                        // remove the key from the config if empty
                        let newConfig = objectCopy(jobConfig);
                        if (newConfig.config.process[0].sample.samples[i]) {
                          delete newConfig.config.process[0].sample.samples[i].network_multiplier;
                          setJobConfig(
                            newConfig.config.process[0].sample.samples,
                            'config.process[0].sample.samples',
                          );
                        }
                      } else {
                        // set it as a string
                        setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`);
                        return;
                      }
                    }}
                    placeholder={`1.0 (default)`}
                  />
                </div>
              </div>

              {modelArch?.additionalSections?.includes('sample.ctrl_img') && (
@@ -1,4 +1,4 @@
import { JobConfig, DatasetConfig } from '@/types';
import { JobConfig, DatasetConfig, SliderConfig } from '@/types';

export const defaultDatasetConfig: DatasetConfig = {
  folder_path: '/path/to/images/folder',
@@ -20,13 +20,22 @@ export const defaultDatasetConfig: DatasetConfig = {
  flip_y: false,
};

export const defaultSliderConfig: SliderConfig = {
  guidance_strength: 3.0,
  anchor_strength: 1.0,
  positive_prompt: 'person who is happy',
  negative_prompt: 'person who is sad',
  target_class: 'person',
  anchor_class: "",
};

export const defaultJobConfig: JobConfig = {
  job: 'extension',
  config: {
    name: 'my_first_lora_v1',
    process: [
      {
        type: 'ui_trainer',
        type: 'diffusion_trainer',
        training_folder: 'output',
        sqlite_db_path: './aitk_db.db',
        device: 'cuda',
@@ -100,7 +109,7 @@ export const defaultJobConfig: JobConfig = {
        height: 1024,
        samples: [
          {
            prompt: 'woman with red hair, playing chess at the park, bomb going off in the background'
            prompt: 'woman with red hair, playing chess at the park, bomb going off in the background',
          },
          {
            prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
@@ -109,7 +118,8 @@ export const defaultJobConfig: JobConfig = {
            prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
          },
          {
            prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
            prompt:
              'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
          },
          {
            prompt: 'a bear building a log cabin in the snow covered mountains',
@@ -121,13 +131,15 @@ export const defaultJobConfig: JobConfig = {
            prompt: 'hipster man with a beard, building a chair, in a wood shop',
          },
          {
            prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
            prompt:
              'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
          },
          {
            prompt: "a man holding a sign that says, 'this is a sign'",
          },
          {
            prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
            prompt:
              'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
          },
        ],
        neg: '',
@@ -1,8 +1,16 @@
import { GroupedSelectOption, SelectOption } from '@/types';
import { GroupedSelectOption, SelectOption, JobConfig } from '@/types';
import { defaultSliderConfig } from './jobConfig';

type Control = 'depth' | 'line' | 'pose' | 'inpaint';

type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type DisableableSections =
  | 'model.quantize'
  | 'train.timestep_type'
  | 'network.conv'
  | 'trigger_word'
  | 'train.diff_output_preservation'
  | 'slider';

type AdditionalSections =
  | 'datasets.control_path'
  | 'datasets.do_i2v'
@@ -439,3 +447,33 @@ export const quantizationOptions: SelectOption[] = [
];

export const defaultQtype = 'qfloat8';

interface JobTypeOption extends SelectOption {
  disableSections?: DisableableSections[];
  processSections?: string[];
  onActivate?: (config: JobConfig) => JobConfig;
  onDeactivate?: (config: JobConfig) => JobConfig;
}

export const jobTypeOptions: JobTypeOption[] = [
  {
    value: 'diffusion_trainer',
    label: 'LoRA Trainer',
    disableSections: ['slider'],
  },
  {
    value: 'concept_slider',
    label: 'Concept Slider',
    disableSections: ['trigger_word', 'train.diff_output_preservation'],
    onActivate: (config: JobConfig) => {
      // add default slider config
      config.config.process[0].slider = { ...defaultSliderConfig };
      return config;
    },
    onDeactivate: (config: JobConfig) => {
      // remove slider config
      delete config.config.process[0].slider;
      return config;
    },
  },
];
@@ -3,6 +3,7 @@
import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
import { jobTypeOptions } from './options';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
@@ -144,6 +145,38 @@ export default function TrainingForm() {
            <div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
          </>
        )}
        {!showAdvancedView && (
          <>
            <div>
              <SelectInput
                value={`${jobConfig?.config.process[0].type}`}
                onChange={value => {
                  // undo current job type changes
                  const currentOption = jobTypeOptions.find(
                    option => option.value === jobConfig?.config.process[0].type,
                  );
                  if (currentOption && currentOption.onDeactivate) {
                    setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig)));
                  }
                  const option = jobTypeOptions.find(option => option.value === value);
                  if (option) {
                    if (option.onActivate) {
                      setJobConfig(option.onActivate(objectCopy(jobConfig)));
                    }
                    jobTypeOptions.forEach(opt => {
                      if (opt.value !== option.value && opt.onDeactivate) {
                        setJobConfig(opt.onDeactivate(objectCopy(jobConfig)));
                      }
                    });
                  }
                  setJobConfig(value, 'config.process[0].type');
                }}
                options={jobTypeOptions}
              />
            </div>
            <div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
          </>
        )}

        <div className="pr-2">
          <Button
@@ -68,8 +68,8 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: Te
TextInput.displayName = 'TextInput';

export interface NumberInputProps extends InputProps {
  value: number;
  onChange: (value: number) => void;
  value: number | null;
  onChange: (value: number | null) => void;
  min?: number;
  max?: number;
}
@@ -143,7 +143,7 @@ export interface ModelConfig {

export interface SampleItem {
  prompt: string;
  width?: number
  width?: number;
  height?: number;
  neg?: string;
  seed?: number;
@@ -153,6 +153,7 @@ export interface SampleItem {
  num_frames?: number;
  ctrl_img?: string | null;
  ctrl_idx?: number;
  network_multiplier?: number;
}

export interface SampleConfig {
@@ -171,14 +172,24 @@ export interface SampleConfig {
  fps: number;
}

export interface SliderConfig {
  guidance_strength?: number;
  anchor_strength?: number;
  positive_prompt?: string;
  negative_prompt?: string;
  target_class?: string;
  anchor_class?: string | null;
}

export interface ProcessConfig {
  type: 'ui_trainer';
  type: string;
  sqlite_db_path?: string;
  training_folder: string;
  performance_log_every: number;
  trigger_word: string | null;
  device: string;
  network?: NetworkConfig;
  slider?: SliderConfig;
  save: SaveConfig;
  datasets: DatasetConfig[];
  train: TrainConfig;
@@ -1 +1 @@
VERSION = "0.5.8"
VERSION = "0.5.9"