mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Built out the ui trainer plugin with db communication
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -178,4 +178,5 @@ cython_debug/
|
||||
/wandb
|
||||
.vscode/settings.json
|
||||
.DS_Store
|
||||
._.DS_Store
|
||||
._.DS_Store
|
||||
aitk_db.db
|
||||
150
extensions_built_in/sd_trainer/UITrainer.py
Normal file
150
extensions_built_in/sd_trainer/UITrainer.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
|
||||
from typing import Literal, Optional
|
||||
|
||||
AITK_Status = Literal["running", "stopped", "error", "completed"]
|
||||
|
||||
|
||||
class UITrainer(SDTrainer):
    """SDTrainer subclass that reports progress/status to the UI's SQLite
    database and honors stop requests the UI writes into the `jobs` table.

    All DB writes are funneled through a single-worker thread pool so they
    are serialized (no write races) and never block the training loop.
    """

    def __init__(self):
        super(UITrainer, self).__init__()
        self.sqlite_db_path = self.config.get("sqlite_db_path", "data.sqlite")
        # The UI identifies this run via the AITK_JOB_ID environment variable.
        self.job_id = os.environ.get("AITK_JOB_ID", None)
        if self.job_id is None:
            raise Exception("AITK_JOB_ID not set")
        self.is_stopping = False
        # Single worker => DB operations are executed strictly in order.
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Initialize the status synchronously so the UI sees "running"
        # immediately, before any training work starts.
        self._update_status_sync("running", "Starting")

    async def _execute_db_operation(self, operation_func):
        """Execute a database operation on the worker thread without blocking
        the event loop. Kept for API compatibility with callers that await it."""
        # get_running_loop() is the correct call inside a coroutine
        # (get_event_loop() is deprecated in this context).
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.thread_pool, operation_func)

    def _db_connect(self):
        """Create a new connection for each operation to avoid locking."""
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn

    def should_stop(self):
        """Return True if the UI has flagged this job to stop (jobs.stop == 1)."""
        conn = self._db_connect()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT stop FROM jobs WHERE job_id = ?", (self.job_id,))
            row = cursor.fetchone()
            return False if row is None else row[0] == 1
        finally:
            # sqlite3's context manager only commits; close explicitly to
            # avoid leaking a connection per poll.
            conn.close()

    def maybe_stop(self):
        """Raise to abort training if the UI requested a stop."""
        if self.should_stop():
            self.is_stopping = True
            # Blocking write on purpose: the status must land before we raise.
            self._update_status_sync("stopped", "Job stopped")
            raise Exception("Job stopped")

    def _update_step_sync(self):
        """Blocking write of the current step count to the jobs table."""
        if not self.accelerator.is_main_process:
            return
        conn = self._db_connect()
        try:
            cursor = conn.cursor()
            cursor.execute("BEGIN IMMEDIATE")  # Get an immediate lock
            try:
                cursor.execute(
                    "UPDATE jobs SET step = ? WHERE job_id = ?",
                    (self.step_num, self.job_id)
                )
            finally:
                cursor.execute("COMMIT")  # Release the lock
        finally:
            conn.close()

    async def _update_step(self):
        """Async wrapper (API compatibility): runs the step write on the pool."""
        if not self.accelerator.is_main_process:
            return
        await self._execute_db_operation(self._update_step_sync)

    def update_step(self):
        """Non-blocking update of the step count."""
        if self.accelerator.is_main_process:
            # BUG FIX: asyncio.create_task() requires a running event loop and
            # raised RuntimeError when called from this synchronous hook.
            # Submitting to the worker thread keeps the call non-blocking.
            self.thread_pool.submit(self._update_step_sync)

    def _update_status_sync(self, status: AITK_Status, info: Optional[str] = None):
        """Blocking write of status (and optionally info) to the jobs table."""
        if not self.accelerator.is_main_process:
            return
        conn = self._db_connect()
        try:
            cursor = conn.cursor()
            cursor.execute("BEGIN IMMEDIATE")  # Get an immediate lock
            try:
                if info is not None:
                    cursor.execute(
                        "UPDATE jobs SET status = ?, info = ? WHERE job_id = ?",
                        (status, info, self.job_id)
                    )
                else:
                    cursor.execute(
                        "UPDATE jobs SET status = ? WHERE job_id = ?",
                        (status, self.job_id)
                    )
            finally:
                cursor.execute("COMMIT")  # Release the lock
        finally:
            conn.close()

    async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Async wrapper (API compatibility): runs the status write on the pool."""
        if not self.accelerator.is_main_process:
            return
        await self._execute_db_operation(
            lambda: self._update_status_sync(status, info))

    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Non-blocking update of status."""
        if self.accelerator.is_main_process:
            # BUG FIX: see update_step() — create_task() has no running loop
            # here; hand the write to the serialized worker thread instead.
            self.thread_pool.submit(self._update_status_sync, status, info)

    def on_error(self, e: Exception):
        super(UITrainer, self).on_error(e)
        # A deliberate stop also raises; don't overwrite "stopped" with "error".
        if self.accelerator.is_main_process and not self.is_stopping:
            self.update_status("error", str(e))

    def done_hook(self):
        super(UITrainer, self).done_hook()
        self.update_status("completed", "Training completed")
        # BUG FIX: wait=True so the queued "completed" write is flushed before
        # the pool is torn down (wait=False could drop the final status).
        self.thread_pool.shutdown(wait=True)

    def end_step_hook(self):
        super(UITrainer, self).end_step_hook()
        self.update_step()
        self.maybe_stop()

    def hook_before_model_load(self):
        super().hook_before_model_load()
        self.update_status("running", "Loading model")

    def before_dataset_load(self):
        super().before_dataset_load()
        self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        self.update_status("running", "Training")

    def sample_step_hook(self, img_num, total_imgs):
        super().sample_step_hook(img_num, total_imgs)
        # Called after each image is generated; img_num is the index of the
        # image just produced. NOTE(review): "img_num - 1" reports "-1 of N"
        # for the first image — confirm the intended numbering with the UI.
        self.update_status(
            "running", f"Generating images - {img_num - 1} of {total_imgs}")

    def sample(self, step=None, is_first=False):
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 1 of {total_imgs}")
        super().sample(step, is_first)
        self.update_status("running", "Training")

    def save(self, step=None):
        self.update_status("running", "Saving model")
        super().save(step)
        self.update_status("running", "Training")
|
||||
@@ -18,6 +18,22 @@ class SDTrainerExtension(Extension):
|
||||
from .SDTrainer import SDTrainer
|
||||
return SDTrainer
|
||||
|
||||
# This is for generic training (LoRA, Dreambooth, FineTuning)
class UITrainerExtension(Extension):
    # Unique identifier: this is how the extension is looked up.
    uid = "ui_trainer"

    # Human-readable name, used when printing.
    name = "UI Trainer"

    @classmethod
    def get_process(cls):
        # Import lazily so the trainer module is only loaded when this
        # extension is actually selected, keeping startup fast.
        from .UITrainer import UITrainer

        return UITrainer
|
||||
|
||||
|
||||
# for backwards compatability
|
||||
class TextualInversionTrainer(SDTrainerExtension):
|
||||
@@ -26,5 +42,5 @@ class TextualInversionTrainer(SDTrainerExtension):
|
||||
|
||||
AI_TOOLKIT_EXTENSIONS = [
|
||||
# you can put a list of extensions here
|
||||
SDTrainerExtension, TextualInversionTrainer
|
||||
SDTrainerExtension, TextualInversionTrainer, UITrainerExtension
|
||||
]
|
||||
|
||||
@@ -24,6 +24,9 @@ class BaseProcess(object):
|
||||
self.performance_log_every = self.get_conf('performance_log_every', 0)
|
||||
|
||||
print(json.dumps(self.config, indent=4))
|
||||
|
||||
def on_error(self, e: Exception):
|
||||
pass
|
||||
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
|
||||
# split key by '.' and recursively get the value
|
||||
|
||||
@@ -439,6 +439,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def post_save_hook(self, save_path):
|
||||
# override in subclass
|
||||
pass
|
||||
|
||||
def done_hook(self):
|
||||
pass
|
||||
|
||||
def end_step_hook(self):
|
||||
pass
|
||||
|
||||
def save(self, step=None):
|
||||
if not self.accelerator.is_main_process:
|
||||
@@ -648,6 +654,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.logger.start()
|
||||
self.prepare_accelerator()
|
||||
|
||||
def sample_step_hook(self, img_num, total_imgs):
|
||||
pass
|
||||
|
||||
def prepare_accelerator(self):
|
||||
# set some config
|
||||
@@ -1419,6 +1427,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
self.sd.add_after_sample_image_hook(self.after_sample_image_hook)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
@@ -2091,6 +2101,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# update various steps
|
||||
self.step_num = step + 1
|
||||
self.grad_accumulation_step += 1
|
||||
self.end_step_hook()
|
||||
|
||||
|
||||
###################################################################
|
||||
@@ -2110,13 +2121,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.logger.finish()
|
||||
self.accelerator.end_training()
|
||||
|
||||
if self.save_config.push_to_hub:
|
||||
if("HF_TOKEN" not in os.environ):
|
||||
interpreter_login(new_session=False, write_permission=True)
|
||||
self.push_to_hub(
|
||||
repo_id=self.save_config.hf_repo_id,
|
||||
private=self.save_config.hf_private
|
||||
)
|
||||
if self.accelerator.is_main_process:
|
||||
# push to hub
|
||||
if self.save_config.push_to_hub:
|
||||
if("HF_TOKEN" not in os.environ):
|
||||
interpreter_login(new_session=False, write_permission=True)
|
||||
self.push_to_hub(
|
||||
repo_id=self.save_config.hf_repo_id,
|
||||
private=self.save_config.hf_private
|
||||
)
|
||||
del (
|
||||
self.sd,
|
||||
unet,
|
||||
@@ -2128,6 +2141,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
flush()
|
||||
self.done_hook()
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
|
||||
@@ -32,4 +32,5 @@ sentencepiece
|
||||
huggingface_hub
|
||||
peft
|
||||
gradio
|
||||
python-slugify
|
||||
python-slugify
|
||||
# NOTE: sqlite3 ships with the Python standard library; it must not be pip-installed (the PyPI "sqlite3" package is not this module)
|
||||
1
run.py
1
run.py
@@ -88,6 +88,7 @@ def main():
|
||||
except Exception as e:
|
||||
print_acc(f"Error running job: {e}")
|
||||
jobs_failed += 1
|
||||
job.process[0].on_error(e)
|
||||
if not args.recover:
|
||||
print_end_message(jobs_completed, jobs_failed)
|
||||
raise e
|
||||
|
||||
@@ -202,6 +202,7 @@ class StableDiffusion:
|
||||
|
||||
# merge in and preview active with -1 weight
|
||||
self.invert_assistant_lora = False
|
||||
self._after_sample_img_hooks = []
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
@@ -1032,6 +1033,14 @@ class StableDiffusion:
|
||||
self.refiner_unet = refiner.unet
|
||||
del refiner
|
||||
flush()
|
||||
|
||||
def _after_sample_image(self, img_num, total_imgs):
|
||||
# process all hooks
|
||||
for hook in self._after_sample_img_hooks:
|
||||
hook(img_num, total_imgs)
|
||||
|
||||
def add_after_sample_image_hook(self, func):
|
||||
self._after_sample_img_hooks.append(func)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_images(
|
||||
@@ -1598,6 +1607,7 @@ class StableDiffusion:
|
||||
|
||||
gen_config.save_image(img, i)
|
||||
gen_config.log_image(img, i)
|
||||
self._after_sample_image(i, len(image_configs))
|
||||
flush()
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
"build": "next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint",
|
||||
"update_db": "npx prisma generate && npx prisma db push",
|
||||
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
||||
},
|
||||
"dependencies": {
|
||||
|
||||
@@ -13,12 +13,15 @@ model Settings {
|
||||
value String
|
||||
}
|
||||
|
||||
|
||||
model Training {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
gpu_id Int
|
||||
job_config String // JSON string
|
||||
created_at DateTime @default(now())
|
||||
updated_at DateTime @updatedAt
|
||||
}
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
gpu_id Int
|
||||
job_config String // JSON string
|
||||
created_at DateTime @default(now())
|
||||
updated_at DateTime @updatedAt
|
||||
status String @default("stopped")
|
||||
stop Boolean @default(false)
|
||||
step Int @default(0)
|
||||
info String @default("")
|
||||
}
|
||||
|
||||
@@ -19,8 +19,9 @@ export const defaultJobConfig: JobConfig = {
|
||||
name: 'my_first_flex_lora_v1',
|
||||
process: [
|
||||
{
|
||||
type: 'sd_trainer',
|
||||
type: 'ui_trainer',
|
||||
training_folder: 'output',
|
||||
sqlite_db_path: './aitk_db.db',
|
||||
device: 'cuda:0',
|
||||
network: {
|
||||
type: 'lora',
|
||||
|
||||
@@ -51,10 +51,29 @@ export const NumberInput = (props: NumberInputProps) => {
|
||||
type="number"
|
||||
value={value}
|
||||
onChange={(e) => {
|
||||
let value = Number(e.target.value);
|
||||
if (isNaN(value)) value = 0;
|
||||
// Use parseFloat instead of Number to properly handle decimal values
|
||||
const rawValue = e.target.value;
|
||||
|
||||
// Special handling for empty or partial inputs
|
||||
if (rawValue === '' || rawValue === '-' || rawValue === '.') {
|
||||
// For empty or partial inputs (like just a minus sign or decimal point),
|
||||
// we need to maintain the raw input in the input field
|
||||
// but pass a valid number to onChange
|
||||
onChange(0);
|
||||
return;
|
||||
}
|
||||
|
||||
let value = parseFloat(rawValue);
|
||||
|
||||
// Handle NaN cases
|
||||
if (isNaN(value)) {
|
||||
value = 0;
|
||||
}
|
||||
|
||||
// Apply min/max constraints only for valid numbers
|
||||
if (min !== undefined && value < min) value = min;
|
||||
if (max !== undefined && value > max) value = max;
|
||||
|
||||
onChange(value);
|
||||
}}
|
||||
className={inputClasses}
|
||||
@@ -62,6 +81,8 @@ export const NumberInput = (props: NumberInputProps) => {
|
||||
required={required}
|
||||
min={min}
|
||||
max={max}
|
||||
// Allow decimal points
|
||||
step="any"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -121,7 +121,8 @@ export interface SampleConfig {
|
||||
}
|
||||
|
||||
export interface ProcessConfig {
|
||||
type: 'sd_trainer';
|
||||
type: 'ui_trainer';
|
||||
sqlite_db_path?: string;
|
||||
training_folder: string;
|
||||
device: string;
|
||||
network?: NetworkConfig;
|
||||
|
||||
Reference in New Issue
Block a user