Added support for new concept slider training script to CLI and UI

This commit is contained in:
Jaret Burkett
2025-09-16 10:22:34 -06:00
parent 3666b112a8
commit 218f673e3d
13 changed files with 996 additions and 78 deletions

View File

@@ -0,0 +1,277 @@
from collections import OrderedDict
from typing import Optional
import torch
from extensions_built_in.sd_trainer.DiffusionTrainer import DiffusionTrainer
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.train_tools import get_torch_dtype
class ConceptSliderTrainerConfig:
def __init__(self, **kwargs):
self.guidance_strength: float = kwargs.get("guidance_strength", 3.0)
self.anchor_strength: float = kwargs.get("anchor_strength", 1.0)
self.positive_prompt: str = kwargs.get("positive_prompt", "")
self.negative_prompt: str = kwargs.get("negative_prompt", "")
self.target_class: str = kwargs.get("target_class", "")
self.anchor_class: Optional[str] = kwargs.get("anchor_class", None)
class ConceptSliderTrainer(DiffusionTrainer):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super().__init__(process_id, job, config, **kwargs)
self.do_guided_loss = True
self.slider: ConceptSliderTrainerConfig = ConceptSliderTrainerConfig(
**self.config.get("slider", {})
)
self.positive_prompt = self.slider.positive_prompt
self.positive_prompt_embeds: Optional[PromptEmbeds] = None
self.negative_prompt = self.slider.negative_prompt
self.negative_prompt_embeds: Optional[PromptEmbeds] = None
self.target_class = self.slider.target_class
self.target_class_embeds: Optional[PromptEmbeds] = None
self.anchor_class = self.slider.anchor_class
self.anchor_class_embeds: Optional[PromptEmbeds] = None
def hook_before_train_loop(self):
# do this before calling parent as it unloads the text encoder if requested
if self.is_caching_text_embeddings:
# make sure model is on cpu for this part so we don't oom.
self.sd.unet.to("cpu")
# cache unconditional embeds (blank prompt)
with torch.no_grad():
self.positive_prompt_embeds = (
self.sd.encode_prompt(
[self.positive_prompt],
)
.to(self.device_torch, dtype=self.sd.torch_dtype)
.detach()
)
self.target_class_embeds = (
self.sd.encode_prompt(
[self.target_class],
)
.to(self.device_torch, dtype=self.sd.torch_dtype)
.detach()
)
self.negative_prompt_embeds = (
self.sd.encode_prompt(
[self.negative_prompt],
)
.to(self.device_torch, dtype=self.sd.torch_dtype)
.detach()
)
if self.anchor_class is not None:
self.anchor_class_embeds = (
self.sd.encode_prompt(
[self.anchor_class],
)
.to(self.device_torch, dtype=self.sd.torch_dtype)
.detach()
)
# call parent
super().hook_before_train_loop()
def get_guided_loss(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: "DataLoaderBatchDTO",
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs,
):
# todo for embeddings, we need to run without trigger words
was_unet_training = self.sd.unet.training
was_network_active = False
if self.network is not None:
was_network_active = self.network.is_active
self.network.is_active = False
# do out prior preds first
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
self.sd.unet.eval()
noisy_latents = noisy_latents.to(self.device_torch, dtype=dtype).detach()
batch_size = noisy_latents.shape[0]
positive_embeds = concat_prompt_embeds(
[self.positive_prompt_embeds] * batch_size
).to(self.device_torch, dtype=dtype)
target_class_embeds = concat_prompt_embeds(
[self.target_class_embeds] * batch_size
).to(self.device_torch, dtype=dtype)
negative_embeds = concat_prompt_embeds(
[self.negative_prompt_embeds] * batch_size
).to(self.device_torch, dtype=dtype)
if self.anchor_class_embeds is not None:
anchor_embeds = concat_prompt_embeds(
[self.anchor_class_embeds] * batch_size
).to(self.device_torch, dtype=dtype)
if self.anchor_class_embeds is not None:
# if we have an anchor, do it
combo_embeds = concat_prompt_embeds(
[
positive_embeds,
target_class_embeds,
negative_embeds,
anchor_embeds,
]
)
num_embeds = 4
else:
combo_embeds = concat_prompt_embeds(
[positive_embeds, target_class_embeds, negative_embeds]
)
num_embeds = 3
# do them in one batch, VRAM should handle it since we are no grad
combo_pred = self.sd.predict_noise(
latents=torch.cat([noisy_latents] * num_embeds, dim=0),
conditional_embeddings=combo_embeds,
timestep=torch.cat([timesteps] * num_embeds, dim=0),
guidance_scale=1.0,
guidance_embedding_scale=1.0,
batch=batch,
)
if self.anchor_class_embeds is not None:
positive_pred, neutral_pred, negative_pred, anchor_target = (
combo_pred.chunk(4, dim=0)
)
else:
anchor_target = None
positive_pred, neutral_pred, negative_pred = combo_pred.chunk(3, dim=0)
# calculate the targets
guidance_scale = self.slider.guidance_strength
# enhance_positive_target = neutral_pred + guidance_scale * (
# positive_pred - negative_pred
# )
# enhance_negative_target = neutral_pred + guidance_scale * (
# negative_pred - positive_pred
# )
# erase_negative_target = neutral_pred - guidance_scale * (
# negative_pred - positive_pred
# )
# erase_positive_target = neutral_pred - guidance_scale * (
# positive_pred - negative_pred
# )
positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred)
negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred)
enhance_positive_target = neutral_pred + guidance_scale * positive
enhance_negative_target = neutral_pred + guidance_scale * negative
erase_negative_target = neutral_pred - guidance_scale * negative
erase_positive_target = neutral_pred - guidance_scale * positive
if was_unet_training:
self.sd.unet.train()
# restore network
if self.network is not None:
self.network.is_active = was_network_active
if self.anchor_class_embeds is not None:
# do a grad inference with our target prompt
embeds = concat_prompt_embeds([target_class_embeds, anchor_embeds]).to(
self.device_torch, dtype=dtype
)
noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0).to(
self.device_torch, dtype=dtype
)
timesteps = torch.cat([timesteps, timesteps], dim=0)
else:
embeds = target_class_embeds.to(self.device_torch, dtype=dtype)
# do positive first
self.network.set_multiplier(1.0)
pred = self.sd.predict_noise(
latents=noisy_latents,
conditional_embeddings=embeds,
timestep=timesteps,
guidance_scale=1.0,
guidance_embedding_scale=1.0,
batch=batch,
)
if self.anchor_class_embeds is not None:
class_pred, anchor_pred = pred.chunk(2, dim=0)
else:
class_pred = pred
anchor_pred = None
# enhance positive loss
enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_positive_target)
erase_loss = torch.nn.functional.mse_loss(class_pred, erase_negative_target)
if anchor_target is None:
anchor_loss = torch.zeros_like(erase_loss)
else:
anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target)
anchor_loss = anchor_loss * self.slider.anchor_strength
# send backward now because gradient checkpointing needs network polarity intact
total_pos_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0
total_pos_loss.backward()
total_pos_loss = total_pos_loss.detach()
# now do negative
self.network.set_multiplier(-1.0)
pred = self.sd.predict_noise(
latents=noisy_latents,
conditional_embeddings=embeds,
timestep=timesteps,
guidance_scale=1.0,
guidance_embedding_scale=1.0,
batch=batch,
)
if self.anchor_class_embeds is not None:
class_pred, anchor_pred = pred.chunk(2, dim=0)
else:
class_pred = pred
anchor_pred = None
# enhance negative loss
enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_negative_target)
erase_loss = torch.nn.functional.mse_loss(class_pred, erase_positive_target)
if anchor_target is None:
anchor_loss = torch.zeros_like(erase_loss)
else:
anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target)
anchor_loss = anchor_loss * self.slider.anchor_strength
total_neg_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0
total_neg_loss.backward()
total_neg_loss = total_neg_loss.detach()
self.network.set_multiplier(1.0)
total_loss = (total_pos_loss + total_neg_loss) / 2.0
# add a grad so backward works right
total_loss.requires_grad_(True)
return total_loss

View File

@@ -0,0 +1,26 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension
# This is for generic training (LoRA, Dreambooth, FineTuning)
class ConceptSliderTrainerTrainer(Extension):
# uid must be unique, it is how the extension is identified
uid = "concept_slider"
# name is the name of the extension for printing
name = "Concept Slider Trainer"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .ConceptSliderTrainer import ConceptSliderTrainer
return ConceptSliderTrainer
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
ConceptSliderTrainerTrainer
]

View File

@@ -0,0 +1,297 @@
from collections import OrderedDict
import os
import sqlite3
import asyncio
import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional
import threading
import time
import signal
AITK_Status = Literal["running", "stopped", "error", "completed"]
class DiffusionTrainer(SDTrainer):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super(DiffusionTrainer, self).__init__(process_id, job, config, **kwargs)
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
self.job_id = os.environ.get("AITK_JOB_ID", None)
self.job_id = self.job_id.strip() if self.job_id is not None else None
self.is_ui_trainer = True
if not os.path.exists(self.sqlite_db_path):
self.is_ui_trainer = False
else:
print(f"Using SQLite database at {self.sqlite_db_path}")
if self.job_id is None:
self.is_ui_trainer = False
else:
print(f"Job ID: \"{self.job_id}\"")
if self.is_ui_trainer:
self.is_stopping = False
# Create a thread pool for database operations
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Track all async tasks
self._async_tasks = []
# Initialize the status
self._run_async_operation(self._update_status("running", "Starting"))
self._stop_watcher_started = False
# self.start_stop_watcher(interval_sec=2.0)
def start_stop_watcher(self, interval_sec: float = 5.0):
"""
Start a daemon thread that periodically checks should_stop()
and terminates the process immediately when triggered.
"""
if not self.is_ui_trainer:
return
if getattr(self, "_stop_watcher_started", False):
return
self._stop_watcher_started = True
t = threading.Thread(
target=self._stop_watcher_thread, args=(interval_sec,), daemon=True
)
t.start()
def _stop_watcher_thread(self, interval_sec: float):
while True:
try:
if self.should_stop():
# Mark and update status (non-blocking; uses existing infra)
self.is_stopping = True
self._run_async_operation(
self._update_status("stopped", "Job stopped (remote)")
)
# Best-effort flush pending async ops
try:
asyncio.run(self.wait_for_all_async())
except RuntimeError:
pass
# Try to stop DB thread pool quickly
try:
self.thread_pool.shutdown(wait=False, cancel_futures=True)
except TypeError:
self.thread_pool.shutdown(wait=False)
print("")
print("****************************************************")
print(" Stop signal received; terminating process. ")
print("****************************************************")
os.kill(os.getpid(), signal.SIGINT)
time.sleep(interval_sec)
except Exception:
time.sleep(interval_sec)
def _run_async_operation(self, coro):
"""Helper method to run an async coroutine and track the task."""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# No event loop exists, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a task and track it
if loop.is_running():
task = asyncio.run_coroutine_threadsafe(coro, loop)
self._async_tasks.append(asyncio.wrap_future(task))
else:
task = loop.create_task(coro)
self._async_tasks.append(task)
loop.run_until_complete(task)
async def _execute_db_operation(self, operation_func):
"""Execute a database operation in a separate thread to avoid blocking."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.thread_pool, operation_func)
def _db_connect(self):
"""Create a new connection for each operation to avoid locking."""
conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
conn.isolation_level = None # Enable autocommit mode
return conn
def should_stop(self):
if not self.is_ui_trainer:
return False
def _check_stop():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT stop FROM Job WHERE id = ?", (self.job_id,))
stop = cursor.fetchone()
return False if stop is None else stop[0] == 1
return _check_stop()
def maybe_stop(self):
if not self.is_ui_trainer:
return
if self.should_stop():
self._run_async_operation(
self._update_status("stopped", "Job stopped"))
self.is_stopping = True
raise Exception("Job stopped")
async def _update_key(self, key, value):
if not self.accelerator.is_main_process:
return
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE")
try:
# Convert the value to string if it's not already
if isinstance(value, str):
value_to_insert = value
else:
value_to_insert = str(value)
# Use parameterized query for both the column name and value
update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
cursor.execute(
update_query, (value_to_insert, self.job_id))
finally:
cursor.execute("COMMIT")
await self._execute_db_operation(_do_update)
def update_step(self):
"""Non-blocking update of the step count."""
if self.accelerator.is_main_process and self.is_ui_trainer:
self._run_async_operation(self._update_key("step", self.step_num))
def update_db_key(self, key, value):
"""Non-blocking update a key in the database."""
if self.accelerator.is_main_process and self.is_ui_trainer:
self._run_async_operation(self._update_key(key, value))
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
if not self.accelerator.is_main_process or not self.is_ui_trainer:
return
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE")
try:
if info is not None:
cursor.execute(
"UPDATE Job SET status = ?, info = ? WHERE id = ?",
(status, info, self.job_id)
)
else:
cursor.execute(
"UPDATE Job SET status = ? WHERE id = ?",
(status, self.job_id)
)
finally:
cursor.execute("COMMIT")
await self._execute_db_operation(_do_update)
def update_status(self, status: AITK_Status, info: Optional[str] = None):
"""Non-blocking update of status."""
if self.accelerator.is_main_process and self.is_ui_trainer:
self._run_async_operation(self._update_status(status, info))
async def wait_for_all_async(self):
"""Wait for all tracked async operations to complete."""
if not self._async_tasks:
return
try:
await asyncio.gather(*self._async_tasks)
except Exception as e:
pass
finally:
# Clear the task list after completion
self._async_tasks.clear()
def on_error(self, e: Exception):
super(DiffusionTrainer, self).on_error(e)
if self.is_ui_trainer:
if self.accelerator.is_main_process and not self.is_stopping:
self.update_status("error", str(e))
self.update_db_key("step", self.last_save_step)
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def handle_timing_print_hook(self, timing_dict):
if "train_loop" not in timing_dict:
print("train_loop not found in timing_dict", timing_dict)
return
seconds_per_iter = timing_dict["train_loop"]
# determine iter/sec or sec/iter
if seconds_per_iter < 1:
iters_per_sec = 1 / seconds_per_iter
self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
else:
self.update_db_key(
"speed_string", f"{seconds_per_iter:.2f} sec/iter")
def done_hook(self):
super(DiffusionTrainer, self).done_hook()
if self.is_ui_trainer:
self.update_status("completed", "Training completed")
# Wait for all async operations to finish before shutting down
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def end_step_hook(self):
super(DiffusionTrainer, self).end_step_hook()
if self.is_ui_trainer:
self.update_step()
self.maybe_stop()
def hook_before_model_load(self):
super().hook_before_model_load()
if self.is_ui_trainer:
self.maybe_stop()
self.update_status("running", "Loading model")
def before_dataset_load(self):
super().before_dataset_load()
if self.is_ui_trainer:
self.maybe_stop()
self.update_status("running", "Loading dataset")
def hook_before_train_loop(self):
super().hook_before_train_loop()
if self.is_ui_trainer:
self.maybe_stop()
self.update_step()
self.update_status("running", "Training")
self.timer.add_after_print_hook(self.handle_timing_print_hook)
def status_update_hook_func(self, string):
self.update_status("running", string)
def hook_after_sd_init_before_load(self):
super().hook_after_sd_init_before_load()
if self.is_ui_trainer:
self.maybe_stop()
self.sd.add_status_update_hook(self.status_update_hook_func)
def sample_step_hook(self, img_num, total_imgs):
super().sample_step_hook(img_num, total_imgs)
if self.is_ui_trainer:
self.maybe_stop()
self.update_status(
"running", f"Generating images - {img_num + 1}/{total_imgs}")
def sample(self, step=None, is_first=False):
self.maybe_stop()
total_imgs = len(self.sample_config.prompts)
self.update_status("running", f"Generating images - 0/{total_imgs}")
super().sample(step, is_first)
self.maybe_stop()
self.update_status("running", "Training")
def save(self, step=None):
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
self.maybe_stop()
self.update_status("running", "Training")

View File

@@ -16,8 +16,10 @@ class SDTrainerExtension(Extension):
def get_process(cls): def get_process(cls):
# import your process class here so it is only loaded when needed and return it # import your process class here so it is only loaded when needed and return it
from .SDTrainer import SDTrainer from .SDTrainer import SDTrainer
return SDTrainer return SDTrainer
# This is for generic training (LoRA, Dreambooth, FineTuning) # This is for generic training (LoRA, Dreambooth, FineTuning)
class UITrainerExtension(Extension): class UITrainerExtension(Extension):
# uid must be unique, it is how the extension is identified # uid must be unique, it is how the extension is identified
@@ -32,9 +34,28 @@ class UITrainerExtension(Extension):
def get_process(cls): def get_process(cls):
# import your process class here so it is only loaded when needed and return it # import your process class here so it is only loaded when needed and return it
from .UITrainer import UITrainer from .UITrainer import UITrainer
return UITrainer return UITrainer
# This is a universal trainer that can be from ui or api
class DiffusionTrainerExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "diffusion_trainer"
# name is the name of the extension for printing
name = "Diffusion Trainer"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .DiffusionTrainer import DiffusionTrainer
return DiffusionTrainer
# for backwards compatability # for backwards compatability
class TextualInversionTrainer(SDTrainerExtension): class TextualInversionTrainer(SDTrainerExtension):
uid = "textual_inversion_trainer" uid = "textual_inversion_trainer"
@@ -42,5 +63,8 @@ class TextualInversionTrainer(SDTrainerExtension):
AI_TOOLKIT_EXTENSIONS = [ AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here # you can put a list of extensions here
SDTrainerExtension, TextualInversionTrainer, UITrainerExtension SDTrainerExtension,
TextualInversionTrainer,
UITrainerExtension,
DiffusionTrainerExtension,
] ]

View File

@@ -57,6 +57,13 @@ class SampleItem:
self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None)
self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) self.ctrl_idx: int = kwargs.get('ctrl_idx', 0)
self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier)
# convert to a number if it is a string
if isinstance(self.network_multiplier, str):
try:
self.network_multiplier = float(self.network_multiplier)
except:
print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0")
self.network_multiplier = 1.0
class SampleConfig: class SampleConfig:

View File

@@ -108,7 +108,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
// We have to ensure certain things are always set // We have to ensure certain things are always set
try { try {
parsed.config.process[0].type = 'ui_trainer'; // parsed.config.process[0].type = 'ui_trainer';
parsed.config.process[0].sqlite_db_path = './aitk_db.db'; parsed.config.process[0].sqlite_db_path = './aitk_db.db';
parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
parsed.config.process[0].device = 'cuda'; parsed.config.process[0].device = 'cuda';

View File

@@ -1,6 +1,13 @@
'use client'; 'use client';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options'; import {
modelArchs,
ModelArch,
groupedModelOptions,
quantizationOptions,
defaultQtype,
jobTypeOptions,
} from './options';
import { defaultDatasetConfig } from './jobConfig'; import { defaultDatasetConfig } from './jobConfig';
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
import { objectCopy } from '@/utils/basic'; import { objectCopy } from '@/utils/basic';
@@ -8,7 +15,7 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp
import Card from '@/components/Card'; import Card from '@/components/Card';
import { X } from 'lucide-react'; import { X } from 'lucide-react';
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
import {FlipHorizontal2, FlipVertical2} from "lucide-react" import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';
type Props = { type Props = {
jobConfig: JobConfig; jobConfig: JobConfig;
@@ -39,6 +46,21 @@ export default function SimpleJob({
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]); }, [jobConfig.config.process[0].model.arch]);
const jobType = useMemo(() => {
return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type);
}, [jobConfig.config.process[0].type]);
const disableSections = useMemo(() => {
let sections: string[] = [];
if (modelArch?.disableSections) {
sections = sections.concat(modelArch.disableSections);
}
if (jobType?.disableSections) {
sections = sections.concat(jobType.disableSections);
}
return sections;
}, [modelArch, jobType]);
const isVideoModel = !!(modelArch?.group === 'video'); const isVideoModel = !!(modelArch?.group === 'video');
const numTopCards = useMemo(() => { const numTopCards = useMemo(() => {
@@ -46,12 +68,14 @@ export default function SimpleJob({
if (modelArch?.additionalSections?.includes('model.multistage')) { if (modelArch?.additionalSections?.includes('model.multistage')) {
count += 1; // add multistage card count += 1; // add multistage card
} }
if (!modelArch?.disableSections?.includes('model.quantize')) { if (!disableSections.includes('model.quantize')) {
count += 1; // add quantization card count += 1; // add quantization card
} }
if (!disableSections.includes('slider')) {
count += 1; // add slider card
}
return count; return count;
}, [modelArch, disableSections]);
}, [modelArch]);
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
@@ -62,6 +86,20 @@ export default function SimpleJob({
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
} }
const numTrainingCols = useMemo(() => {
let count = 4;
if (!disableSections.includes('train.diff_output_preservation')) {
count += 1;
}
return count;
}, [disableSections]);
let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6';
if (numTrainingCols == 5) {
trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6';
}
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0; const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
if (!hasARA) { if (!hasARA) {
@@ -78,7 +116,7 @@ export default function SimpleJob({
let ARAs: SelectOption[] = []; let ARAs: SelectOption[] = [];
if (modelArch.accuracyRecoveryAdapters) { if (modelArch.accuracyRecoveryAdapters) {
for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) { for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
ARAs.push({ value, label }); ARAs.push({ value, label });
} }
} }
if (ARAs.length > 0) { if (ARAs.length > 0) {
@@ -124,19 +162,21 @@ export default function SimpleJob({
onChange={value => setGpuIDs(value)} onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/> />
<TextInput {disableSections.includes('trigger_word') ? null : (
label="Trigger Word" <TextInput
value={jobConfig.config.process[0].trigger_word || ''} label="Trigger Word"
docKey="config.process[0].trigger_word" value={jobConfig.config.process[0].trigger_word || ''}
onChange={(value: string | null) => { docKey="config.process[0].trigger_word"
if (value?.trim() === '') { onChange={(value: string | null) => {
value = null; if (value?.trim() === '') {
} value = null;
setJobConfig(value, 'config.process[0].trigger_word'); }
}} setJobConfig(value, 'config.process[0].trigger_word');
placeholder="" }}
required placeholder=""
/> required
/>
)}
</Card> </Card>
{/* Model Configuration Section */} {/* Model Configuration Section */}
@@ -223,7 +263,7 @@ export default function SimpleJob({
</FormGroup> </FormGroup>
)} )}
</Card> </Card>
{modelArch?.disableSections?.includes('model.quantize') ? null : ( {disableSections.includes('model.quantize') ? null : (
<Card title="Quantization"> <Card title="Quantization">
<SelectInput <SelectInput
label="Transformer" label="Transformer"
@@ -270,14 +310,14 @@ export default function SimpleJob({
/> />
</FormGroup> </FormGroup>
<NumberInput <NumberInput
label="Switch Every" label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every} value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')} onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1" placeholder="eg. 1"
docKey={'train.switch_boundary_every'} docKey={'train.switch_boundary_every'}
min={1} min={1}
required required
/> />
</Card> </Card>
)} )}
<Card title="Target"> <Card title="Target">
@@ -319,7 +359,7 @@ export default function SimpleJob({
max={1024} max={1024}
required required
/> />
{modelArch?.disableSections?.includes('network.conv') ? null : ( {disableSections.includes('network.conv') ? null : (
<NumberInput <NumberInput
label="Conv Rank" label="Conv Rank"
value={jobConfig.config.process[0].network.conv} value={jobConfig.config.process[0].network.conv}
@@ -336,6 +376,38 @@ export default function SimpleJob({
</> </>
)} )}
</Card> </Card>
{!disableSections.includes('slider') && (
<Card title="Slider">
<TextInput
label="Target Class"
className=""
value={jobConfig.config.process[0].slider?.target_class ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.target_class')}
placeholder="eg. person"
/>
<TextInput
label="Positive Prompt"
className=""
value={jobConfig.config.process[0].slider?.positive_prompt ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.positive_prompt')}
placeholder="eg. person who is happy"
/>
<TextInput
label="Negative Prompt"
className=""
value={jobConfig.config.process[0].slider?.negative_prompt ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.negative_prompt')}
placeholder="eg. person who is sad"
/>
<TextInput
label="Anchor Class"
className=""
value={jobConfig.config.process[0].slider?.anchor_class ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.anchor_class')}
placeholder=""
/>
</Card>
)}
<Card title="Save"> <Card title="Save">
<SelectInput <SelectInput
label="Data Type" label="Data Type"
@@ -367,7 +439,7 @@ export default function SimpleJob({
</div> </div>
<div> <div>
<Card title="Training"> <Card title="Training">
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6"> <div className={trainingBarClass}>
<div> <div>
<NumberInput <NumberInput
label="Batch Size" label="Batch Size"
@@ -426,11 +498,11 @@ export default function SimpleJob({
/> />
</div> </div>
<div> <div>
{modelArch?.disableSections?.includes('train.timestep_type') ? null : ( {disableSections.includes('train.timestep_type') ? null : (
<SelectInput <SelectInput
label="Timestep Type" label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type} value={jobConfig.config.process[0].train.timestep_type}
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false} disabled={disableSections.includes('train.timestep_type') || false}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')} onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[ options={[
{ value: 'sigmoid', label: 'Sigmoid' }, { value: 'sigmoid', label: 'Sigmoid' },
@@ -508,33 +580,39 @@ export default function SimpleJob({
</FormGroup> </FormGroup>
</div> </div>
<div> <div>
<FormGroup label="Regularization"> {disableSections.includes('train.diff_output_preservation') ? null : (
<Checkbox
label="Differtial Output Preservation"
className="pt-1"
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/>
</FormGroup>
{jobConfig.config.process[0].train.diff_output_preservation && (
<> <>
<NumberInput <FormGroup label="Regularization">
label="DOP Loss Multiplier" <Checkbox
className="pt-2" label="Differential Output Preservation"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number} className="pt-1"
onChange={value => checked={jobConfig.config.process[0].train.diff_output_preservation || false}
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
} />
placeholder="eg. 1.0" </FormGroup>
min={0} {jobConfig.config.process[0].train.diff_output_preservation && (
/> <>
<TextInput <NumberInput
label="DOP Preservation Class" label="DOP Loss Multiplier"
className="pt-2" className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string} value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} onChange={value =>
placeholder="eg. woman" setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
/> }
placeholder="eg. 1.0"
min={0}
/>
<TextInput
label="DOP Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
}
placeholder="eg. woman"
/>
</>
)}
</> </>
)} )}
</div> </div>
@@ -641,12 +719,20 @@ export default function SimpleJob({
</FormGroup> </FormGroup>
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2"> <FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
<Checkbox <Checkbox
label={<>Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" /></>} label={
<>
Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" />
</>
}
checked={dataset.flip_x || false} checked={dataset.flip_x || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
/> />
<Checkbox <Checkbox
label={<>Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" /></>} label={
<>
Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" />
</>
}
checked={dataset.flip_y || false} checked={dataset.flip_y || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
/> />
@@ -812,7 +898,7 @@ export default function SimpleJob({
onChange={value => { onChange={value => {
setJobConfig(value, 'config.process[0].train.skip_first_sample'); setJobConfig(value, 'config.process[0].train.skip_first_sample');
// cannot do both, so disable the other // cannot do both, so disable the other
if (value){ if (value) {
setJobConfig(false, 'config.process[0].train.force_first_sample'); setJobConfig(false, 'config.process[0].train.force_first_sample');
} }
}} }}
@@ -827,7 +913,7 @@ export default function SimpleJob({
onChange={value => { onChange={value => {
setJobConfig(value, 'config.process[0].train.force_first_sample'); setJobConfig(value, 'config.process[0].train.force_first_sample');
// cannot do both, so disable the other // cannot do both, so disable the other
if (value){ if (value) {
setJobConfig(false, 'config.process[0].train.skip_first_sample'); setJobConfig(false, 'config.process[0].train.skip_first_sample');
} }
}} }}
@@ -841,7 +927,7 @@ export default function SimpleJob({
onChange={value => { onChange={value => {
setJobConfig(value, 'config.process[0].train.disable_sampling'); setJobConfig(value, 'config.process[0].train.disable_sampling');
// cannot do both, so disable the other // cannot do both, so disable the other
if (value){ if (value) {
setJobConfig(false, 'config.process[0].train.force_first_sample'); setJobConfig(false, 'config.process[0].train.force_first_sample');
} }
}} }}
@@ -866,6 +952,113 @@ export default function SimpleJob({
placeholder="Enter prompt" placeholder="Enter prompt"
required required
/> />
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mt-2">
<TextInput
label={`Width`}
value={sample.width ? `${sample.width}` : ''}
onChange={value => {
// remove any non-numeric characters
value = value.replace(/\D/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].width;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`);
} else {
console.warn('Invalid width value:', value);
}
}
}}
placeholder={`${jobConfig.config.process[0].sample.width} (default)`}
/>
<TextInput
label={`Height`}
value={sample.height ? `${sample.height}` : ''}
onChange={value => {
// remove any non-numeric characters
value = value.replace(/\D/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].height;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`);
} else {
console.warn('Invalid height value:', value);
}
}
}}
placeholder={`${jobConfig.config.process[0].sample.height} (default)`}
/>
<TextInput
label={`Seed`}
value={sample.seed ? `${sample.seed}` : ''}
onChange={value => {
// remove any non-numeric characters
value = value.replace(/\D/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].seed;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
setJobConfig(intValue, `config.process[0].sample.samples[${i}].seed`);
} else {
console.warn('Invalid seed value:', value);
}
}
}}
placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`}
/>
<TextInput
label={`LoRA Scale`}
value={sample.network_multiplier ? `${sample.network_multiplier}` : ''}
onChange={value => {
// remove any non-numeric, - or . characters
value = value.replace(/[^0-9.-]/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].network_multiplier;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
// set it as a string
setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`);
return;
}
}}
placeholder={`1.0 (default)`}
/>
</div>
</div> </div>
{modelArch?.additionalSections?.includes('sample.ctrl_img') && ( {modelArch?.additionalSections?.includes('sample.ctrl_img') && (

View File

@@ -1,4 +1,4 @@
import { JobConfig, DatasetConfig } from '@/types'; import { JobConfig, DatasetConfig, SliderConfig } from '@/types';
export const defaultDatasetConfig: DatasetConfig = { export const defaultDatasetConfig: DatasetConfig = {
folder_path: '/path/to/images/folder', folder_path: '/path/to/images/folder',
@@ -20,13 +20,22 @@ export const defaultDatasetConfig: DatasetConfig = {
flip_y: false, flip_y: false,
}; };
export const defaultSliderConfig: SliderConfig = {
guidance_strength: 3.0,
anchor_strength: 1.0,
positive_prompt: 'person who is happy',
negative_prompt: 'person who is sad',
target_class: 'person',
anchor_class: "",
};
export const defaultJobConfig: JobConfig = { export const defaultJobConfig: JobConfig = {
job: 'extension', job: 'extension',
config: { config: {
name: 'my_first_lora_v1', name: 'my_first_lora_v1',
process: [ process: [
{ {
type: 'ui_trainer', type: 'diffusion_trainer',
training_folder: 'output', training_folder: 'output',
sqlite_db_path: './aitk_db.db', sqlite_db_path: './aitk_db.db',
device: 'cuda', device: 'cuda',
@@ -100,7 +109,7 @@ export const defaultJobConfig: JobConfig = {
height: 1024, height: 1024,
samples: [ samples: [
{ {
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background' prompt: 'woman with red hair, playing chess at the park, bomb going off in the background',
}, },
{ {
prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe', prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
@@ -109,7 +118,8 @@ export const defaultJobConfig: JobConfig = {
prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini', prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
}, },
{ {
prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', prompt:
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
}, },
{ {
prompt: 'a bear building a log cabin in the snow covered mountains', prompt: 'a bear building a log cabin in the snow covered mountains',
@@ -121,13 +131,15 @@ export const defaultJobConfig: JobConfig = {
prompt: 'hipster man with a beard, building a chair, in a wood shop', prompt: 'hipster man with a beard, building a chair, in a wood shop',
}, },
{ {
prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', prompt:
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
}, },
{ {
prompt: "a man holding a sign that says, 'this is a sign'", prompt: "a man holding a sign that says, 'this is a sign'",
}, },
{ {
prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', prompt:
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
}, },
], ],
neg: '', neg: '',

View File

@@ -1,8 +1,16 @@
import { GroupedSelectOption, SelectOption } from '@/types'; import { GroupedSelectOption, SelectOption, JobConfig } from '@/types';
import { defaultSliderConfig } from './jobConfig';
type Control = 'depth' | 'line' | 'pose' | 'inpaint'; type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; type DisableableSections =
| 'model.quantize'
| 'train.timestep_type'
| 'network.conv'
| 'trigger_word'
| 'train.diff_output_preservation'
| 'slider';
type AdditionalSections = type AdditionalSections =
| 'datasets.control_path' | 'datasets.control_path'
| 'datasets.do_i2v' | 'datasets.do_i2v'
@@ -439,3 +447,33 @@ export const quantizationOptions: SelectOption[] = [
]; ];
export const defaultQtype = 'qfloat8'; export const defaultQtype = 'qfloat8';
interface JobTypeOption extends SelectOption {
disableSections?: DisableableSections[];
processSections?: string[];
onActivate?: (config: JobConfig) => JobConfig;
onDeactivate?: (config: JobConfig) => JobConfig;
}
export const jobTypeOptions: JobTypeOption[] = [
{
value: 'diffusion_trainer',
label: 'LoRA Trainer',
disableSections: ['slider'],
},
{
value: 'concept_slider',
label: 'Concept Slider',
disableSections: ['trigger_word', 'train.diff_output_preservation'],
onActivate: (config: JobConfig) => {
// add default slider config
config.config.process[0].slider = { ...defaultSliderConfig };
return config;
},
onDeactivate: (config: JobConfig) => {
// remove slider config
delete config.config.process[0].slider;
return config;
},
},
];

View File

@@ -3,6 +3,7 @@
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation'; import { useSearchParams, useRouter } from 'next/navigation';
import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig'; import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
import { jobTypeOptions } from './options';
import { JobConfig } from '@/types'; import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic'; import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks'; import { useNestedState } from '@/utils/hooks';
@@ -144,6 +145,38 @@ export default function TrainingForm() {
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div> <div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
</> </>
)} )}
{!showAdvancedView && (
<>
<div>
<SelectInput
value={`${jobConfig?.config.process[0].type}`}
onChange={value => {
// undo current job type changes
const currentOption = jobTypeOptions.find(
option => option.value === jobConfig?.config.process[0].type,
);
if (currentOption && currentOption.onDeactivate) {
setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig)));
}
const option = jobTypeOptions.find(option => option.value === value);
if (option) {
if (option.onActivate) {
setJobConfig(option.onActivate(objectCopy(jobConfig)));
}
jobTypeOptions.forEach(opt => {
if (opt.value !== option.value && opt.onDeactivate) {
setJobConfig(opt.onDeactivate(objectCopy(jobConfig)));
}
});
}
setJobConfig(value, 'config.process[0].type');
}}
options={jobTypeOptions}
/>
</div>
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
</>
)}
<div className="pr-2"> <div className="pr-2">
<Button <Button

View File

@@ -68,8 +68,8 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: Te
TextInput.displayName = 'TextInput'; TextInput.displayName = 'TextInput';
export interface NumberInputProps extends InputProps { export interface NumberInputProps extends InputProps {
value: number; value: number | null;
onChange: (value: number) => void; onChange: (value: number | null) => void;
min?: number; min?: number;
max?: number; max?: number;
} }

View File

@@ -143,7 +143,7 @@ export interface ModelConfig {
export interface SampleItem { export interface SampleItem {
prompt: string; prompt: string;
width?: number width?: number;
height?: number; height?: number;
neg?: string; neg?: string;
seed?: number; seed?: number;
@@ -153,6 +153,7 @@ export interface SampleItem {
num_frames?: number; num_frames?: number;
ctrl_img?: string | null; ctrl_img?: string | null;
ctrl_idx?: number; ctrl_idx?: number;
network_multiplier?: number;
} }
export interface SampleConfig { export interface SampleConfig {
@@ -171,14 +172,24 @@ export interface SampleConfig {
fps: number; fps: number;
} }
export interface SliderConfig {
guidance_strength?: number;
anchor_strength?: number;
positive_prompt?: string;
negative_prompt?: string;
target_class?: string;
anchor_class?: string | null;
}
export interface ProcessConfig { export interface ProcessConfig {
type: 'ui_trainer'; type: string;
sqlite_db_path?: string; sqlite_db_path?: string;
training_folder: string; training_folder: string;
performance_log_every: number; performance_log_every: number;
trigger_word: string | null; trigger_word: string | null;
device: string; device: string;
network?: NetworkConfig; network?: NetworkConfig;
slider?: SliderConfig;
save: SaveConfig; save: SaveConfig;
datasets: DatasetConfig[]; datasets: DatasetConfig[];
train: TrainConfig; train: TrainConfig;

View File

@@ -1 +1 @@
VERSION = "0.5.8" VERSION = "0.5.9"