diff --git a/README.md b/README.md index a7119538..f974dafc 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ cd ai-toolkit git submodule update --init --recursive python -m venv venv .\venv\Scripts\activate -pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 +pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124 pip install -r requirements.txt ``` diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index f0fdba68..feaf4345 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -14,8 +14,12 @@ class UITrainer(SDTrainer): def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): super(UITrainer, self).__init__(process_id, job, config, **kwargs) self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db") + if not os.path.exists(self.sqlite_db_path): + raise Exception(f"SQLite database not found at {self.sqlite_db_path}") print(f"Using SQLite database at {self.sqlite_db_path}") self.job_id = os.environ.get("AITK_JOB_ID", None) + self.job_id = self.job_id.strip() if self.job_id is not None else None + print(f"Job ID: \"{self.job_id}\"") if self.job_id is None: raise Exception("AITK_JOB_ID not set") self.is_stopping = False diff --git a/run.py b/run.py index d4ccda2a..9a2b7f19 100644 --- a/run.py +++ b/run.py @@ -95,6 +95,14 @@ def main(): if not args.recover: print_end_message(jobs_completed, jobs_failed) raise e + except KeyboardInterrupt as e: + try: + job.process[0].on_error(e) + except Exception as e2: + print_acc(f"Error running on_error: {e2}") + if not args.recover: + print_end_message(jobs_completed, jobs_failed) + raise e if __name__ == '__main__': diff --git a/ui/src/app/api/jobs/[jobID]/start/route.ts b/ui/src/app/api/jobs/[jobID]/start/route.ts index 88755ef4..5067e90d 100644 --- a/ui/src/app/api/jobs/[jobID]/start/route.ts +++ b/ui/src/app/api/jobs/[jobID]/start/route.ts @@ -6,6 +6,7 @@ import path from 'path'; import fs from 'fs'; import os from 'os'; import { getTrainingFolder, getHFToken } from '@/server/settings'; +const isWindows = process.platform === 'win32'; const prisma = new PrismaClient(); @@ -52,13 +53,13 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s let pythonPath = 'python'; // use .venv or venv if it exists if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) { - if (os.platform() === 'win32') { + if (isWindows) { pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe'); } else { pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python'); } } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) { - if (os.platform() === 'win32') { + if (isWindows) { pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe'); } else { pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python'); @@ -80,23 +81,55 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s additionalEnv.HF_TOKEN = hfToken; } - // console.log( - // 'Spawning command:', - // `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`, - // ); + let cmd = `${pythonPath} ${runFilePath} ${configPath}`; + for (const key in additionalEnv) { + if (os.platform() === 'win32') { + cmd = `set ${key}=${additionalEnv[key]} && ${cmd}`; + } else { + cmd = `${key}=${additionalEnv[key]} ${cmd}`; + } + } + + console.log('Spawning command:', cmd); // start job - const subprocess = spawn(pythonPath, [runFilePath, configPath], { - detached: true, - stdio: 'ignore', - env: { - ...process.env, - ...additionalEnv, - }, - cwd: TOOLKIT_ROOT, - }); + if (isWindows) { + // For Windows, use 'cmd.exe' to open a new command window + const subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, runFilePath, configPath], { + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + windowsHide: false, + }); + + subprocess.unref(); + } else { + // For non-Windows platforms, use your original approach + const subprocess = spawn(pythonPath, [runFilePath, configPath], { + detached: true, + stdio: 'ignore', + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + }); + + subprocess.unref(); + } + // const subprocess = spawn(pythonPath, [runFilePath, configPath], { + // detached: true, + // stdio: 'ignore', + // env: { + // ...process.env, + // ...additionalEnv, + // }, + // cwd: TOOLKIT_ROOT, + // }); - subprocess.unref(); + // subprocess.unref(); return NextResponse.json(job); } diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index f4090a7c..05361cda 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -18,6 +18,8 @@ export const options = { 'config.process[0].model.quantize_te': [true, false], 'config.process[0].model.is_flux': [true, false], 'config.process[0].train.bypass_guidance_embedding': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, }, { @@ -27,6 +29,8 @@ export const options = { 'config.process[0].model.quantize': [true, false], 'config.process[0].model.quantize_te': [true, false], 'config.process[0].model.is_flux': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, }, { @@ -36,12 +40,16 @@ export const options = { 'config.process[0].model.quantize': [false, false], 'config.process[0].model.quantize_te': [true, false], 'config.process[0].model.is_lumina2': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, }, { name_or_path: 'ostris/objective-reality', dev_only: true, defaults: { + 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], }, }, ], diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index dec65d43..2da1f3ca 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -362,7 +362,10 @@ export default function TrainingForm() { className="pt-2" value={jobConfig.config.process[0].train.noise_scheduler} onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')} - options={[{ value: 'flowmatch', label: 'FlowMatch' }]} + options={[ + { value: 'flowmatch', label: 'FlowMatch' }, + { value: 'ddpm', label: 'DDPM' }, + ]} />