On Windows, spawn the job in a new cmd terminal window. This should be working now, but I have not been able to confirm it on my system.

This commit is contained in:
Jaret Burkett
2025-02-24 08:54:56 -07:00
parent 093f14ac19
commit 440ba5fb3d
6 changed files with 78 additions and 19 deletions

View File

@@ -38,7 +38,7 @@ cd ai-toolkit
git submodule update --init --recursive
python -m venv venv
.\venv\Scripts\activate
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt
```

View File

@@ -14,8 +14,12 @@ class UITrainer(SDTrainer):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
if not os.path.exists(self.sqlite_db_path):
raise Exception(f"SQLite database not found at {self.sqlite_db_path}")
print(f"Using SQLite database at {self.sqlite_db_path}")
self.job_id = os.environ.get("AITK_JOB_ID", None)
self.job_id = self.job_id.strip() if self.job_id is not None else None
print(f"Job ID: \"{self.job_id}\"")
if self.job_id is None:
raise Exception("AITK_JOB_ID not set")
self.is_stopping = False

8
run.py
View File

@@ -95,6 +95,14 @@ def main():
if not args.recover:
print_end_message(jobs_completed, jobs_failed)
raise e
except KeyboardInterrupt as e:
try:
job.process[0].on_error(e)
except Exception as e2:
print_acc(f"Error running on_error: {e2}")
if not args.recover:
print_end_message(jobs_completed, jobs_failed)
raise e
if __name__ == '__main__':

View File

@@ -6,6 +6,7 @@ import path from 'path';
import fs from 'fs';
import os from 'os';
import { getTrainingFolder, getHFToken } from '@/server/settings';
const isWindows = process.platform === 'win32';
const prisma = new PrismaClient();
@@ -52,13 +53,13 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
let pythonPath = 'python';
// use .venv or venv if it exists
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
if (os.platform() === 'win32') {
if (isWindows) {
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
} else {
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
}
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
if (os.platform() === 'win32') {
if (isWindows) {
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
} else {
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
@@ -80,23 +81,55 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
additionalEnv.HF_TOKEN = hfToken;
}
// console.log(
// 'Spawning command:',
// `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
// );
let cmd = `${pythonPath} ${runFilePath} ${configPath}`;
for (const key in additionalEnv) {
if (os.platform() === 'win32') {
cmd = `set ${key}=${additionalEnv[key]} && ${cmd}`;
} else {
cmd = `${key}=${additionalEnv[key]} ${cmd}`;
}
}
console.log('Spawning command:', cmd);
// start job
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
detached: true,
stdio: 'ignore',
env: {
...process.env,
...additionalEnv,
},
cwd: TOOLKIT_ROOT,
});
if (isWindows) {
// For Windows, use 'cmd.exe' to open a new command window
const subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, runFilePath, configPath], {
env: {
...process.env,
...additionalEnv,
},
cwd: TOOLKIT_ROOT,
windowsHide: false,
});
subprocess.unref();
} else {
// For non-Windows platforms, use your original approach
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
detached: true,
stdio: 'ignore',
env: {
...process.env,
...additionalEnv,
},
cwd: TOOLKIT_ROOT,
});
subprocess.unref();
}
// const subprocess = spawn(pythonPath, [runFilePath, configPath], {
// detached: true,
// stdio: 'ignore',
// env: {
// ...process.env,
// ...additionalEnv,
// },
// cwd: TOOLKIT_ROOT,
// });
subprocess.unref();
// subprocess.unref();
return NextResponse.json(job);
}

View File

@@ -18,6 +18,8 @@ export const options = {
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
'config.process[0].train.bypass_guidance_embedding': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
@@ -27,6 +29,8 @@ export const options = {
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
@@ -36,12 +40,16 @@ export const options = {
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_lumina2': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'ostris/objective-reality',
dev_only: true,
defaults: {
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
},
},
],

View File

@@ -362,7 +362,10 @@ export default function TrainingForm() {
className="pt-2"
value={jobConfig.config.process[0].train.noise_scheduler}
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
options={[
{ value: 'flowmatch', label: 'FlowMatch' },
{ value: 'ddpm', label: 'DDPM' },
]}
/>
</div>
<div>
@@ -516,7 +519,10 @@ export default function TrainingForm() {
className="pt-2"
value={jobConfig.config.process[0].sample.sampler}
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
options={[
{ value: 'flowmatch', label: 'FlowMatch' },
{ value: 'ddpm', label: 'DDPM' },
]}
/>
</div>
<div>