diff --git a/extensions_built_in/concept_slider/ConceptSliderTrainer.py b/extensions_built_in/concept_slider/ConceptSliderTrainer.py new file mode 100644 index 00000000..eba1d74e --- /dev/null +++ b/extensions_built_in/concept_slider/ConceptSliderTrainer.py @@ -0,0 +1,277 @@ +from collections import OrderedDict +from typing import Optional + +import torch + +from extensions_built_in.sd_trainer.DiffusionTrainer import DiffusionTrainer +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.train_tools import get_torch_dtype + + +class ConceptSliderTrainerConfig: + def __init__(self, **kwargs): + self.guidance_strength: float = kwargs.get("guidance_strength", 3.0) + self.anchor_strength: float = kwargs.get("anchor_strength", 1.0) + self.positive_prompt: str = kwargs.get("positive_prompt", "") + self.negative_prompt: str = kwargs.get("negative_prompt", "") + self.target_class: str = kwargs.get("target_class", "") + self.anchor_class: Optional[str] = kwargs.get("anchor_class", None) + + +class ConceptSliderTrainer(DiffusionTrainer): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.do_guided_loss = True + + self.slider: ConceptSliderTrainerConfig = ConceptSliderTrainerConfig( + **self.config.get("slider", {}) + ) + + self.positive_prompt = self.slider.positive_prompt + self.positive_prompt_embeds: Optional[PromptEmbeds] = None + self.negative_prompt = self.slider.negative_prompt + self.negative_prompt_embeds: Optional[PromptEmbeds] = None + self.target_class = self.slider.target_class + self.target_class_embeds: Optional[PromptEmbeds] = None + self.anchor_class = self.slider.anchor_class + self.anchor_class_embeds: Optional[PromptEmbeds] = None + + def hook_before_train_loop(self): + # do this before calling parent as it unloads the text encoder if requested + if self.is_caching_text_embeddings: + # make sure model is on cpu for this part so we don't oom. + self.sd.unet.to("cpu") + + # cache unconditional embeds (blank prompt) + with torch.no_grad(): + self.positive_prompt_embeds = ( + self.sd.encode_prompt( + [self.positive_prompt], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + self.target_class_embeds = ( + self.sd.encode_prompt( + [self.target_class], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + self.negative_prompt_embeds = ( + self.sd.encode_prompt( + [self.negative_prompt], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + if self.anchor_class is not None: + self.anchor_class_embeds = ( + self.sd.encode_prompt( + [self.anchor_class], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + # call parent + super().hook_before_train_loop() + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: "DataLoaderBatchDTO", + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs, + ): + # todo for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + + + # do out prior preds first + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + self.sd.unet.eval() + noisy_latents = noisy_latents.to(self.device_torch, dtype=dtype).detach() + + batch_size = noisy_latents.shape[0] + + positive_embeds = concat_prompt_embeds( + [self.positive_prompt_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + target_class_embeds = concat_prompt_embeds( + [self.target_class_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + negative_embeds = concat_prompt_embeds( + [self.negative_prompt_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + + if self.anchor_class_embeds is not None: + anchor_embeds = concat_prompt_embeds( + [self.anchor_class_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + + if self.anchor_class_embeds is not None: + # if we have an anchor, do it + combo_embeds = concat_prompt_embeds( + [ + positive_embeds, + target_class_embeds, + negative_embeds, + anchor_embeds, + ] + ) + num_embeds = 4 + else: + combo_embeds = concat_prompt_embeds( + [positive_embeds, target_class_embeds, negative_embeds] + ) + num_embeds = 3 + + # do them in one batch, VRAM should handle it since we are no grad + combo_pred = self.sd.predict_noise( + latents=torch.cat([noisy_latents] * num_embeds, dim=0), + conditional_embeddings=combo_embeds, + timestep=torch.cat([timesteps] * num_embeds, dim=0), + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + positive_pred, neutral_pred, negative_pred, anchor_target = ( + combo_pred.chunk(4, dim=0) + ) + else: + anchor_target = None + positive_pred, neutral_pred, negative_pred = combo_pred.chunk(3, dim=0) + + # calculate the targets + guidance_scale = self.slider.guidance_strength + + # enhance_positive_target = neutral_pred + guidance_scale * ( + # positive_pred - negative_pred + # ) + # enhance_negative_target = neutral_pred + guidance_scale * ( + # negative_pred - positive_pred + # ) + # erase_negative_target = neutral_pred - guidance_scale * ( + # negative_pred - positive_pred + # ) + # erase_positive_target = neutral_pred - guidance_scale * ( + # positive_pred - negative_pred + # ) + + positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred) + negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred) + + enhance_positive_target = neutral_pred + guidance_scale * positive + enhance_negative_target = neutral_pred + guidance_scale * negative + erase_negative_target = neutral_pred - guidance_scale * negative + erase_positive_target = neutral_pred - guidance_scale * positive + + if was_unet_training: + self.sd.unet.train() + + # restore network + if self.network is not None: + self.network.is_active = was_network_active + + if self.anchor_class_embeds is not None: + # do a grad inference with our target prompt + embeds = concat_prompt_embeds([target_class_embeds, anchor_embeds]).to( + self.device_torch, dtype=dtype + ) + + noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0).to( + self.device_torch, dtype=dtype + ) + timesteps = torch.cat([timesteps, timesteps], dim=0) + else: + embeds = target_class_embeds.to(self.device_torch, dtype=dtype) + + # do positive first + self.network.set_multiplier(1.0) + pred = self.sd.predict_noise( + latents=noisy_latents, + conditional_embeddings=embeds, + timestep=timesteps, + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + class_pred, anchor_pred = pred.chunk(2, dim=0) + else: + class_pred = pred + anchor_pred = None + + # enhance positive loss + enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_positive_target) + + erase_loss = torch.nn.functional.mse_loss(class_pred, erase_negative_target) + + if anchor_target is None: + anchor_loss = torch.zeros_like(erase_loss) + else: + anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) + + anchor_loss = anchor_loss * self.slider.anchor_strength + + # send backward now because gradient checkpointing needs network polarity intact + total_pos_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0 + total_pos_loss.backward() + total_pos_loss = total_pos_loss.detach() + + # now do negative + self.network.set_multiplier(-1.0) + pred = self.sd.predict_noise( + latents=noisy_latents, + conditional_embeddings=embeds, + timestep=timesteps, + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + class_pred, anchor_pred = pred.chunk(2, dim=0) + else: + class_pred = pred + anchor_pred = None + + # enhance negative loss + enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_negative_target) + erase_loss = torch.nn.functional.mse_loss(class_pred, erase_positive_target) + + if anchor_target is None: + anchor_loss = torch.zeros_like(erase_loss) + else: + anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) + anchor_loss = anchor_loss * self.slider.anchor_strength + total_neg_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0 + total_neg_loss.backward() + total_neg_loss = total_neg_loss.detach() + + self.network.set_multiplier(1.0) + + total_loss = (total_pos_loss + total_neg_loss) / 2.0 + + # add a grad so backward works right + total_loss.requires_grad_(True) + return total_loss diff --git a/extensions_built_in/concept_slider/__init__.py b/extensions_built_in/concept_slider/__init__.py new file mode 100644 index 00000000..7a624b18 --- /dev/null +++ b/extensions_built_in/concept_slider/__init__.py @@ -0,0 +1,26 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class ConceptSliderTrainerTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "concept_slider" + + # name is the name of the extension for printing + name = "Concept Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ConceptSliderTrainer import ConceptSliderTrainer + + return ConceptSliderTrainer + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ConceptSliderTrainerTrainer +] diff --git a/extensions_built_in/sd_trainer/DiffusionTrainer.py b/extensions_built_in/sd_trainer/DiffusionTrainer.py new file mode 100644 index 00000000..730a1aee --- /dev/null +++ b/extensions_built_in/sd_trainer/DiffusionTrainer.py @@ -0,0 +1,297 @@ +from collections import OrderedDict +import os +import sqlite3 +import asyncio +import concurrent.futures +from extensions_built_in.sd_trainer.SDTrainer import SDTrainer +from typing import Literal, Optional +import threading +import time +import signal + +AITK_Status = Literal["running", "stopped", "error", "completed"] + + +class DiffusionTrainer(SDTrainer): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super(DiffusionTrainer, self).__init__(process_id, job, config, **kwargs) + self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db") + self.job_id = os.environ.get("AITK_JOB_ID", None) + self.job_id = self.job_id.strip() if self.job_id is not None else None + self.is_ui_trainer = True + if not os.path.exists(self.sqlite_db_path): + self.is_ui_trainer = False + else: + print(f"Using SQLite database at {self.sqlite_db_path}") + if self.job_id is None: + self.is_ui_trainer = False + else: + print(f"Job ID: \"{self.job_id}\"") + + if self.is_ui_trainer: + self.is_stopping = False + # Create a thread pool for database operations + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Track all async tasks + self._async_tasks = [] + # Initialize the status + self._run_async_operation(self._update_status("running", "Starting")) + self._stop_watcher_started = False + # self.start_stop_watcher(interval_sec=2.0) + + def start_stop_watcher(self, interval_sec: float = 5.0): + """ + Start a daemon thread that periodically checks should_stop() + and terminates the process immediately when triggered. + """ + if not self.is_ui_trainer: + return + if getattr(self, "_stop_watcher_started", False): + return + self._stop_watcher_started = True + t = threading.Thread( + target=self._stop_watcher_thread, args=(interval_sec,), daemon=True + ) + t.start() + + def _stop_watcher_thread(self, interval_sec: float): + while True: + try: + if self.should_stop(): + # Mark and update status (non-blocking; uses existing infra) + self.is_stopping = True + self._run_async_operation( + self._update_status("stopped", "Job stopped (remote)") + ) + # Best-effort flush pending async ops + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + # Try to stop DB thread pool quickly + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + print("") + print("****************************************************") + print(" Stop signal received; terminating process. ") + print("****************************************************") + os.kill(os.getpid(), signal.SIGINT) + time.sleep(interval_sec) + except Exception: + time.sleep(interval_sec) + + def _run_async_operation(self, coro): + """Helper method to run an async coroutine and track the task.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a task and track it + if loop.is_running(): + task = asyncio.run_coroutine_threadsafe(coro, loop) + self._async_tasks.append(asyncio.wrap_future(task)) + else: + task = loop.create_task(coro) + self._async_tasks.append(task) + loop.run_until_complete(task) + + async def _execute_db_operation(self, operation_func): + """Execute a database operation in a separate thread to avoid blocking.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.thread_pool, operation_func) + + def _db_connect(self): + """Create a new connection for each operation to avoid locking.""" + conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0) + conn.isolation_level = None # Enable autocommit mode + return conn + + def should_stop(self): + if not self.is_ui_trainer: + return False + def _check_stop(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT stop FROM Job WHERE id = ?", (self.job_id,)) + stop = cursor.fetchone() + return False if stop is None else stop[0] == 1 + + return _check_stop() + + def maybe_stop(self): + if not self.is_ui_trainer: + return + if self.should_stop(): + self._run_async_operation( + self._update_status("stopped", "Job stopped")) + self.is_stopping = True + raise Exception("Job stopped") + + async def _update_key(self, key, value): + if not self.accelerator.is_main_process: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + # Convert the value to string if it's not already + if isinstance(value, str): + value_to_insert = value + else: + value_to_insert = str(value) + + # Use parameterized query for both the column name and value + update_query = f"UPDATE Job SET {key} = ? WHERE id = ?" + cursor.execute( + update_query, (value_to_insert, self.job_id)) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_step(self): + """Non-blocking update of the step count.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_key("step", self.step_num)) + + def update_db_key(self, key, value): + """Non-blocking update a key in the database.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_key(key, value)) + + async def _update_status(self, status: AITK_Status, info: Optional[str] = None): + if not self.accelerator.is_main_process or not self.is_ui_trainer: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + if info is not None: + cursor.execute( + "UPDATE Job SET status = ?, info = ? WHERE id = ?", + (status, info, self.job_id) + ) + else: + cursor.execute( + "UPDATE Job SET status = ? WHERE id = ?", + (status, self.job_id) + ) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_status(self, status: AITK_Status, info: Optional[str] = None): + """Non-blocking update of status.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_status(status, info)) + + async def wait_for_all_async(self): + """Wait for all tracked async operations to complete.""" + if not self._async_tasks: + return + + try: + await asyncio.gather(*self._async_tasks) + except Exception as e: + pass + finally: + # Clear the task list after completion + self._async_tasks.clear() + + def on_error(self, e: Exception): + super(DiffusionTrainer, self).on_error(e) + if self.is_ui_trainer: + if self.accelerator.is_main_process and not self.is_stopping: + self.update_status("error", str(e)) + self.update_db_key("step", self.last_save_step) + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) + + def handle_timing_print_hook(self, timing_dict): + if "train_loop" not in timing_dict: + print("train_loop not found in timing_dict", timing_dict) + return + seconds_per_iter = timing_dict["train_loop"] + # determine iter/sec or sec/iter + if seconds_per_iter < 1: + iters_per_sec = 1 / seconds_per_iter + self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec") + else: + self.update_db_key( + "speed_string", f"{seconds_per_iter:.2f} sec/iter") + + def done_hook(self): + super(DiffusionTrainer, self).done_hook() + if self.is_ui_trainer: + self.update_status("completed", "Training completed") + # Wait for all async operations to finish before shutting down + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) + + def end_step_hook(self): + super(DiffusionTrainer, self).end_step_hook() + if self.is_ui_trainer: + self.update_step() + self.maybe_stop() + + def hook_before_model_load(self): + super().hook_before_model_load() + if self.is_ui_trainer: + self.maybe_stop() + self.update_status("running", "Loading model") + + def before_dataset_load(self): + super().before_dataset_load() + if self.is_ui_trainer: + self.maybe_stop() + self.update_status("running", "Loading dataset") + + def hook_before_train_loop(self): + super().hook_before_train_loop() + if self.is_ui_trainer: + self.maybe_stop() + self.update_step() + self.update_status("running", "Training") + self.timer.add_after_print_hook(self.handle_timing_print_hook) + + def status_update_hook_func(self, string): + self.update_status("running", string) + + def hook_after_sd_init_before_load(self): + super().hook_after_sd_init_before_load() + if self.is_ui_trainer: + self.maybe_stop() + self.sd.add_status_update_hook(self.status_update_hook_func) + + def sample_step_hook(self, img_num, total_imgs): + super().sample_step_hook(img_num, total_imgs) + if self.is_ui_trainer: + self.maybe_stop() + self.update_status( + "running", f"Generating images - {img_num + 1}/{total_imgs}") + + def sample(self, step=None, is_first=False): + self.maybe_stop() + total_imgs = len(self.sample_config.prompts) + self.update_status("running", f"Generating images - 0/{total_imgs}") + super().sample(step, is_first) + self.maybe_stop() + self.update_status("running", "Training") + + def save(self, step=None): + self.maybe_stop() + self.update_status("running", "Saving model") + super().save(step) + self.maybe_stop() + self.update_status("running", "Training") diff --git a/extensions_built_in/sd_trainer/__init__.py b/extensions_built_in/sd_trainer/__init__.py index 47c84fa1..065ff818 100644 --- a/extensions_built_in/sd_trainer/__init__.py +++ b/extensions_built_in/sd_trainer/__init__.py @@ -16,8 +16,10 @@ class SDTrainerExtension(Extension): def get_process(cls): # import your process class here so it is only loaded when needed and return it from .SDTrainer import SDTrainer + return SDTrainer + # This is for generic training (LoRA, Dreambooth, FineTuning) class UITrainerExtension(Extension): # uid must be unique, it is how the extension is identified @@ -32,9 +34,28 @@ class UITrainerExtension(Extension): def get_process(cls): # import your process class here so it is only loaded when needed and return it from .UITrainer import UITrainer + return UITrainer +# This is a universal trainer that can be from ui or api +class DiffusionTrainerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "diffusion_trainer" + + # name is the name of the extension for printing + name = "Diffusion Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .DiffusionTrainer import DiffusionTrainer + + return DiffusionTrainer + + # for backwards compatability class TextualInversionTrainer(SDTrainerExtension): uid = "textual_inversion_trainer" @@ -42,5 +63,8 @@ class TextualInversionTrainer(SDTrainerExtension): AI_TOOLKIT_EXTENSIONS = [ # you can put a list of extensions here - SDTrainerExtension, TextualInversionTrainer, UITrainerExtension + SDTrainerExtension, + TextualInversionTrainer, + UITrainerExtension, + DiffusionTrainerExtension, ] diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 2526d86e..d1247184 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -57,6 +57,13 @@ class SampleItem: self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) + # convert to a number if it is a string + if isinstance(self.network_multiplier, str): + try: + self.network_multiplier = float(self.network_multiplier) + except: + print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0") + self.network_multiplier = 1.0 class SampleConfig: diff --git a/ui/src/app/jobs/new/AdvancedJob.tsx b/ui/src/app/jobs/new/AdvancedJob.tsx index 6a0d4388..66938384 100644 --- a/ui/src/app/jobs/new/AdvancedJob.tsx +++ b/ui/src/app/jobs/new/AdvancedJob.tsx @@ -108,7 +108,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props // We have to ensure certain things are always set try { - parsed.config.process[0].type = 'ui_trainer'; + // parsed.config.process[0].type = 'ui_trainer'; parsed.config.process[0].sqlite_db_path = './aitk_db.db'; parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; parsed.config.process[0].device = 'cuda'; diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index fb42599a..30b4c155 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -1,6 +1,13 @@ 'use client'; import { useMemo } from 'react'; -import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options'; +import { + modelArchs, + ModelArch, + groupedModelOptions, + quantizationOptions, + defaultQtype, + jobTypeOptions, +} from './options'; import { defaultDatasetConfig } from './jobConfig'; import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; import { objectCopy } from '@/utils/basic'; @@ -8,7 +15,7 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp import Card from '@/components/Card'; import { X } from 'lucide-react'; import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; -import {FlipHorizontal2, FlipVertical2} from "lucide-react" +import { FlipHorizontal2, FlipVertical2 } from 'lucide-react'; type Props = { jobConfig: JobConfig; @@ -39,6 +46,21 @@ export default function SimpleJob({ return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; }, [jobConfig.config.process[0].model.arch]); + const jobType = useMemo(() => { + return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type); + }, [jobConfig.config.process[0].type]); + + const disableSections = useMemo(() => { + let sections: string[] = []; + if (modelArch?.disableSections) { + sections = sections.concat(modelArch.disableSections); + } + if (jobType?.disableSections) { + sections = sections.concat(jobType.disableSections); + } + return sections; + }, [modelArch, jobType]); + const isVideoModel = !!(modelArch?.group === 'video'); const numTopCards = useMemo(() => { @@ -46,12 +68,14 @@ export default function SimpleJob({ if (modelArch?.additionalSections?.includes('model.multistage')) { count += 1; // add multistage card } - if (!modelArch?.disableSections?.includes('model.quantize')) { + if (!disableSections.includes('model.quantize')) { count += 1; // add quantization card } + if (!disableSections.includes('slider')) { + count += 1; // add slider card + } return count; - - }, [modelArch]); + }, [modelArch, disableSections]); let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; @@ -62,6 +86,20 @@ export default function SimpleJob({ topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; } + const numTrainingCols = useMemo(() => { + let count = 4; + if (!disableSections.includes('train.diff_output_preservation')) { + count += 1; + } + return count; + }, [disableSections]); + + let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'; + + if (numTrainingCols == 5) { + trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'; + } + const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0; if (!hasARA) { @@ -78,7 +116,7 @@ export default function SimpleJob({ let ARAs: SelectOption[] = []; if (modelArch.accuracyRecoveryAdapters) { for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) { - ARAs.push({ value, label }); + ARAs.push({ value, label }); } } if (ARAs.length > 0) { @@ -124,19 +162,21 @@ export default function SimpleJob({ onChange={value => setGpuIDs(value)} options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} /> - { - if (value?.trim() === '') { - value = null; - } - setJobConfig(value, 'config.process[0].trigger_word'); - }} - placeholder="" - required - /> + {disableSections.includes('trigger_word') ? null : ( + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'config.process[0].trigger_word'); + }} + placeholder="" + required + /> + )} {/* Model Configuration Section */} @@ -223,7 +263,7 @@ export default function SimpleJob({ )} - {modelArch?.disableSections?.includes('model.quantize') ? null : ( + {disableSections.includes('model.quantize') ? null : ( setJobConfig(value, 'config.process[0].train.switch_boundary_every')} - placeholder="eg. 1" - docKey={'train.switch_boundary_every'} - min={1} - required - /> + label="Switch Every" + value={jobConfig.config.process[0].train.switch_boundary_every} + onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')} + placeholder="eg. 1" + docKey={'train.switch_boundary_every'} + min={1} + required + /> )} @@ -319,7 +359,7 @@ export default function SimpleJob({ max={1024} required /> - {modelArch?.disableSections?.includes('network.conv') ? null : ( + {disableSections.includes('network.conv') ? null : ( )} + {!disableSections.includes('slider') && ( + + setJobConfig(value, 'config.process[0].slider.target_class')} + placeholder="eg. person" + /> + setJobConfig(value, 'config.process[0].slider.positive_prompt')} + placeholder="eg. person who is happy" + /> + setJobConfig(value, 'config.process[0].slider.negative_prompt')} + placeholder="eg. person who is sad" + /> + setJobConfig(value, 'config.process[0].slider.anchor_class')} + placeholder="" + /> + + )}
-
+
- {modelArch?.disableSections?.includes('train.timestep_type') ? null : ( + {disableSections.includes('train.timestep_type') ? null : ( setJobConfig(value, 'config.process[0].train.timestep_type')} options={[ { value: 'sigmoid', label: 'Sigmoid' }, @@ -508,33 +580,39 @@ export default function SimpleJob({
- - setJobConfig(value, 'config.process[0].train.diff_output_preservation')} - /> - - {jobConfig.config.process[0].train.diff_output_preservation && ( + {disableSections.includes('train.diff_output_preservation') ? null : ( <> - - setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') - } - placeholder="eg. 1.0" - min={0} - /> - setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} - placeholder="eg. woman" - /> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation')} + /> + + {jobConfig.config.process[0].train.diff_output_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_class') + } + placeholder="eg. woman" + /> + + )} )}
@@ -641,12 +719,20 @@ export default function SimpleJob({ Flip X } + label={ + <> + Flip X + + } checked={dataset.flip_x || false} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)} /> Flip Y } + label={ + <> + Flip Y + + } checked={dataset.flip_y || false} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)} /> @@ -812,7 +898,7 @@ export default function SimpleJob({ onChange={value => { setJobConfig(value, 'config.process[0].train.skip_first_sample'); // cannot do both, so disable the other - if (value){ + if (value) { setJobConfig(false, 'config.process[0].train.force_first_sample'); } }} @@ -827,7 +913,7 @@ export default function SimpleJob({ onChange={value => { setJobConfig(value, 'config.process[0].train.force_first_sample'); // cannot do both, so disable the other - if (value){ + if (value) { setJobConfig(false, 'config.process[0].train.skip_first_sample'); } }} @@ -841,7 +927,7 @@ export default function SimpleJob({ onChange={value => { setJobConfig(value, 'config.process[0].train.disable_sampling'); // cannot do both, so disable the other - if (value){ + if (value) { setJobConfig(false, 'config.process[0].train.force_first_sample'); } }} @@ -866,6 +952,113 @@ export default function SimpleJob({ placeholder="Enter prompt" required /> +
+ { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].width; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`); + } else { + console.warn('Invalid width value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.width} (default)`} + /> + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].height; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`); + } else { + console.warn('Invalid height value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.height} (default)`} + /> + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].seed; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].seed`); + } else { + console.warn('Invalid seed value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`} + /> + { + // remove any non-numeric, - or . characters + value = value.replace(/[^0-9.-]/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].network_multiplier; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + // set it as a string + setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`); + return; + } + }} + placeholder={`1.0 (default)`} + /> +
{modelArch?.additionalSections?.includes('sample.ctrl_img') && ( diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 40dad0be..bcdbcb12 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -1,4 +1,4 @@ -import { JobConfig, DatasetConfig } from '@/types'; +import { JobConfig, DatasetConfig, SliderConfig } from '@/types'; export const defaultDatasetConfig: DatasetConfig = { folder_path: '/path/to/images/folder', @@ -20,13 +20,22 @@ export const defaultDatasetConfig: DatasetConfig = { flip_y: false, }; +export const defaultSliderConfig: SliderConfig = { + guidance_strength: 3.0, + anchor_strength: 1.0, + positive_prompt: 'person who is happy', + negative_prompt: 'person who is sad', + target_class: 'person', + anchor_class: "", +}; + export const defaultJobConfig: JobConfig = { job: 'extension', config: { name: 'my_first_lora_v1', process: [ { - type: 'ui_trainer', + type: 'diffusion_trainer', training_folder: 'output', sqlite_db_path: './aitk_db.db', device: 'cuda', @@ -100,7 +109,7 @@ export const defaultJobConfig: JobConfig = { height: 1024, samples: [ { - prompt: 'woman with red hair, playing chess at the park, bomb going off in the background' + prompt: 'woman with red hair, playing chess at the park, bomb going off in the background', }, { prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe', @@ -109,7 +118,8 @@ export const defaultJobConfig: JobConfig = { prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini', }, { - prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', + prompt: + 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', }, { prompt: 'a bear building a log cabin in the snow covered mountains', @@ -121,13 +131,15 @@ export const defaultJobConfig: JobConfig = { prompt: 'hipster man with a beard, building a chair, in a wood shop', }, { - prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', + prompt: + 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', }, { prompt: "a man holding a sign that says, 'this is a sign'", }, { - prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', + prompt: + 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', }, ], neg: '', diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 71fdc9d8..7735ae35 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -1,8 +1,16 @@ -import { GroupedSelectOption, SelectOption } from '@/types'; +import { GroupedSelectOption, SelectOption, JobConfig } from '@/types'; +import { defaultSliderConfig } from './jobConfig'; type Control = 'depth' | 'line' | 'pose' | 'inpaint'; -type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; +type DisableableSections = + | 'model.quantize' + | 'train.timestep_type' + | 'network.conv' + | 'trigger_word' + | 'train.diff_output_preservation' + | 'slider'; + type AdditionalSections = | 'datasets.control_path' | 'datasets.do_i2v' @@ -439,3 +447,33 @@ export const quantizationOptions: SelectOption[] = [ ]; export const defaultQtype = 'qfloat8'; + +interface JobTypeOption extends SelectOption { + disableSections?: DisableableSections[]; + processSections?: string[]; + onActivate?: (config: JobConfig) => JobConfig; + onDeactivate?: (config: JobConfig) => JobConfig; +} + +export const jobTypeOptions: JobTypeOption[] = [ + { + value: 'diffusion_trainer', + label: 'LoRA Trainer', + disableSections: ['slider'], + }, + { + value: 'concept_slider', + label: 'Concept Slider', + disableSections: ['trigger_word', 'train.diff_output_preservation'], + onActivate: (config: JobConfig) => { + // add default slider config + config.config.process[0].slider = { ...defaultSliderConfig }; + return config; + }, + onDeactivate: (config: JobConfig) => { + // remove slider config + delete config.config.process[0].slider; + return config; + }, + }, +]; diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index fb2b8546..b57495d8 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -3,6 +3,7 @@ import { useEffect, useState } from 'react'; import { useSearchParams, useRouter } from 'next/navigation'; import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig'; +import { jobTypeOptions } from './options'; import { JobConfig } from '@/types'; import { objectCopy } from '@/utils/basic'; import { useNestedState } from '@/utils/hooks'; @@ -144,6 +145,38 @@ export default function TrainingForm() {
)} + {!showAdvancedView && ( + <> +
+ { + // undo current job type changes + const currentOption = jobTypeOptions.find( + option => option.value === jobConfig?.config.process[0].type, + ); + if (currentOption && currentOption.onDeactivate) { + setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig))); + } + const option = jobTypeOptions.find(option => option.value === value); + if (option) { + if (option.onActivate) { + setJobConfig(option.onActivate(objectCopy(jobConfig))); + } + jobTypeOptions.forEach(opt => { + if (opt.value !== option.value && opt.onDeactivate) { + setJobConfig(opt.onDeactivate(objectCopy(jobConfig))); + } + }); + } + setJobConfig(value, 'config.process[0].type'); + }} + options={jobTypeOptions} + /> +
+
+ + )}