import os
import sqlite3
import concurrent.futures
from typing import Literal, Optional

from extensions_built_in.sd_trainer.SDTrainer import SDTrainer

# Status values mirrored into the UI database's `jobs.status` column.
AITK_Status = Literal["running", "stopped", "error", "completed"]


class UITrainer(SDTrainer):
    """SDTrainer subclass that reports progress to the UI's SQLite database.

    The UI polls the `jobs` row identified by the AITK_JOB_ID environment
    variable for `status`, `info` and `step`, and sets `jobs.stop = 1` to
    request a cooperative shutdown.

    NOTE(review): the original implementation mixed `asyncio.run(...)` with
    `asyncio.create_task(...)`. The latter requires a *running* event loop,
    which does not exist inside the synchronous training hooks, so every
    step update would raise `RuntimeError: no running event loop`. All
    fire-and-forget writes now go through the single-worker thread pool
    directly, which also serializes them against SQLite's write lock.
    """

    def __init__(self):
        super(UITrainer, self).__init__()
        # Path to the SQLite db shared with the UI process.
        self.sqlite_db_path = self.config.get("sqlite_db_path", "data.sqlite")
        self.job_id = os.environ.get("AITK_JOB_ID", None)
        if self.job_id is None:
            raise Exception("AITK_JOB_ID not set")
        self.is_stopping = False
        # One worker so queued db writes never interleave with each other.
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Initialize the status synchronously so the UI sees "running"
        # before any heavy startup work begins.
        self._update_status_sync("running", "Starting")

    def _db_connect(self):
        """Create a new autocommit connection per operation to avoid locking."""
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn

    def _write(self, sql, params):
        """Execute one write statement under an immediate lock."""
        with self._db_connect() as conn:
            cursor = conn.cursor()
            cursor.execute("BEGIN IMMEDIATE")  # Get an immediate lock
            try:
                cursor.execute(sql, params)
            finally:
                cursor.execute("COMMIT")  # Release the lock

    def _submit_write(self, sql, params):
        """Queue a write on the worker thread without blocking training."""
        self.thread_pool.submit(self._write, sql, params)

    def should_stop(self):
        """Synchronously read the db's stop flag for this job."""
        with self._db_connect() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT stop FROM jobs WHERE job_id = ?", (self.job_id,))
            stop = cursor.fetchone()
            return False if stop is None else stop[0] == 1

    def maybe_stop(self):
        """Abort the training loop if the UI requested a stop."""
        if self.should_stop():
            self._update_status_sync("stopped", "Job stopped")
            self.is_stopping = True
            raise Exception("Job stopped")

    def update_step(self):
        """Non-blocking update of the step count."""
        if self.accelerator.is_main_process:
            self._submit_write(
                "UPDATE jobs SET step = ? WHERE job_id = ?",
                (self.step_num, self.job_id),
            )

    def _status_statement(self, status: AITK_Status, info: Optional[str]):
        """Build the (sql, params) pair for a status/info update."""
        if info is not None:
            return (
                "UPDATE jobs SET status = ?, info = ? WHERE job_id = ?",
                (status, info, self.job_id),
            )
        return (
            "UPDATE jobs SET status = ? WHERE job_id = ?",
            (status, self.job_id),
        )

    def _update_status_sync(self, status: AITK_Status, info: Optional[str] = None):
        """Blocking status write — used where durability matters (init/stop)."""
        if not self.accelerator.is_main_process:
            return
        sql, params = self._status_statement(status, info)
        self._write(sql, params)

    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Non-blocking update of status."""
        if self.accelerator.is_main_process:
            sql, params = self._status_statement(status, info)
            self._submit_write(sql, params)

    def on_error(self, e: Exception):
        super(UITrainer, self).on_error(e)
        # A deliberate stop already wrote "stopped"; don't overwrite it.
        if self.accelerator.is_main_process and not self.is_stopping:
            self.update_status("error", str(e))

    def done_hook(self):
        super(UITrainer, self).done_hook()
        self.update_status("completed", "Training completed")
        # wait=True so the queued "completed" write lands before teardown
        # (shutting down with wait=False could drop the final status).
        self.thread_pool.shutdown(wait=True)

    def end_step_hook(self):
        super(UITrainer, self).end_step_hook()
        self.update_step()
        self.maybe_stop()

    def hook_before_model_load(self):
        super().hook_before_model_load()
        self.update_status("running", "Loading model")

    def before_dataset_load(self):
        super().before_dataset_load()
        self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        self.update_status("running", "Training")

    def sample_step_hook(self, img_num, total_imgs):
        super().sample_step_hook(img_num, total_imgs)
        # subtract a since this is called after the image is generated
        self.update_status(
            "running", f"Generating images - {img_num - 1} of {total_imgs}")

    def sample(self, step=None, is_first=False):
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 1 of {total_imgs}")
        super().sample(step, is_first)
        self.update_status("running", "Training")

    def save(self, step=None):
        self.update_status("running", "Saving model")
        super().save(step)
        self.update_status("running", "Training")
TextualInversionTrainer, UITrainerExtension ] diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py index f0644607..c58724c9 100644 --- a/jobs/process/BaseProcess.py +++ b/jobs/process/BaseProcess.py @@ -24,6 +24,9 @@ class BaseProcess(object): self.performance_log_every = self.get_conf('performance_log_every', 0) print(json.dumps(self.config, indent=4)) + + def on_error(self, e: Exception): + pass def get_conf(self, key, default=None, required=False, as_type=None): # split key by '.' and recursively get the value diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e30ddae0..a180aec5 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -439,6 +439,12 @@ class BaseSDTrainProcess(BaseTrainProcess): def post_save_hook(self, save_path): # override in subclass pass + + def done_hook(self): + pass + + def end_step_hook(self): + pass def save(self, step=None): if not self.accelerator.is_main_process: @@ -648,6 +654,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.logger.start() self.prepare_accelerator() + def sample_step_hook(self, img_num, total_imgs): + pass def prepare_accelerator(self): # set some config @@ -1419,6 +1427,8 @@ class BaseSDTrainProcess(BaseTrainProcess): ) # run base sd process run self.sd.load_model() + + self.sd.add_after_sample_image_hook(self.after_sample_image_hook) dtype = get_torch_dtype(self.train_config.dtype) @@ -2091,6 +2101,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # update various steps self.step_num = step + 1 self.grad_accumulation_step += 1 + self.end_step_hook() ################################################################### @@ -2110,13 +2121,15 @@ class BaseSDTrainProcess(BaseTrainProcess): self.logger.finish() self.accelerator.end_training() - if self.save_config.push_to_hub: - if("HF_TOKEN" not in os.environ): - interpreter_login(new_session=False, write_permission=True) - self.push_to_hub( - 
repo_id=self.save_config.hf_repo_id, - private=self.save_config.hf_private - ) + if self.accelerator.is_main_process: + # push to hub + if self.save_config.push_to_hub: + if("HF_TOKEN" not in os.environ): + interpreter_login(new_session=False, write_permission=True) + self.push_to_hub( + repo_id=self.save_config.hf_repo_id, + private=self.save_config.hf_private + ) del ( self.sd, unet, @@ -2128,6 +2141,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ) flush() + self.done_hook() def push_to_hub( self, diff --git a/requirements.txt b/requirements.txt index 4040e760..abf9bc64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,4 +32,5 @@ sentencepiece huggingface_hub peft gradio -python-slugify \ No newline at end of file +python-slugify +sqlite3 \ No newline at end of file diff --git a/run.py b/run.py index 9a3e57fd..ce3553a9 100644 --- a/run.py +++ b/run.py @@ -88,6 +88,7 @@ def main(): except Exception as e: print_acc(f"Error running job: {e}") jobs_failed += 1 + job.process[0].on_error(e) if not args.recover: print_end_message(jobs_completed, jobs_failed) raise e diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 7a765770..4fe01f9d 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -202,6 +202,7 @@ class StableDiffusion: # merge in and preview active with -1 weight self.invert_assistant_lora = False + self._after_sample_img_hooks = [] def load_model(self): if self.is_loaded: @@ -1032,6 +1033,14 @@ class StableDiffusion: self.refiner_unet = refiner.unet del refiner flush() + + def _after_sample_image(self, img_num, total_imgs): + # process all hooks + for hook in self._after_sample_img_hooks: + hook(img_num, total_imgs) + + def add_after_sample_image_hook(self, func): + self._after_sample_img_hooks.append(func) @torch.no_grad() def generate_images( @@ -1598,6 +1607,7 @@ class StableDiffusion: gen_config.save_image(img, i) gen_config.log_image(img, i) + self._after_sample_image(i, 
len(image_configs)) flush() if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): diff --git a/ui/package.json b/ui/package.json index e8f91700..c6b8b97a 100644 --- a/ui/package.json +++ b/ui/package.json @@ -7,6 +7,7 @@ "build": "next build", "start": "next start", "lint": "next lint", + "update_db": "npx prisma generate && npx prisma db push", "format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\"" }, "dependencies": { diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma index 4f3380e3..76a56899 100644 --- a/ui/prisma/schema.prisma +++ b/ui/prisma/schema.prisma @@ -13,12 +13,15 @@ model Settings { value String } - model Training { - id String @id @default(uuid()) - name String - gpu_id Int - job_config String // JSON string - created_at DateTime @default(now()) - updated_at DateTime @updatedAt -} \ No newline at end of file + id String @id @default(uuid()) + name String + gpu_id Int + job_config String // JSON string + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + status String @default("stopped") + stop Boolean @default(false) + step Int @default(0) + info String @default("") +} diff --git a/ui/src/app/train/jobConfig.ts b/ui/src/app/train/jobConfig.ts index fc97dcdf..38c1305c 100644 --- a/ui/src/app/train/jobConfig.ts +++ b/ui/src/app/train/jobConfig.ts @@ -19,8 +19,9 @@ export const defaultJobConfig: JobConfig = { name: 'my_first_flex_lora_v1', process: [ { - type: 'sd_trainer', + type: 'ui_trainer', training_folder: 'output', + sqlite_db_path: './aitk_db.db', device: 'cuda:0', network: { type: 'lora', diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index 5b4c3132..09afd0d0 100644 --- a/ui/src/components/formInputs.tsx +++ b/ui/src/components/formInputs.tsx @@ -51,10 +51,29 @@ export const NumberInput = (props: NumberInputProps) => { type="number" value={value} onChange={(e) => { - let value = Number(e.target.value); - if (isNaN(value)) value = 0; + // Use 
parseFloat instead of Number to properly handle decimal values + const rawValue = e.target.value; + + // Special handling for empty or partial inputs + if (rawValue === '' || rawValue === '-' || rawValue === '.') { + // For empty or partial inputs (like just a minus sign or decimal point), + // we need to maintain the raw input in the input field + // but pass a valid number to onChange + onChange(0); + return; + } + + let value = parseFloat(rawValue); + + // Handle NaN cases + if (isNaN(value)) { + value = 0; + } + + // Apply min/max constraints only for valid numbers if (min !== undefined && value < min) value = min; if (max !== undefined && value > max) value = max; + onChange(value); }} className={inputClasses} @@ -62,6 +81,8 @@ export const NumberInput = (props: NumberInputProps) => { required={required} min={min} max={max} + // Allow decimal points + step="any" /> ); diff --git a/ui/src/types.ts b/ui/src/types.ts index e0141558..e851c2f1 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -121,7 +121,8 @@ export interface SampleConfig { } export interface ProcessConfig { - type: 'sd_trainer'; + type: 'ui_trainer'; + sqlite_db_path?: string; training_folder: string; device: string; network?: NetworkConfig;