mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Spawn windows in an cmd terminal. Should be working now, but not sure on my system
This commit is contained in:
@@ -38,7 +38,7 @@ cd ai-toolkit
|
|||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
.\venv\Scripts\activate
|
.\venv\Scripts\activate
|
||||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,12 @@ class UITrainer(SDTrainer):
|
|||||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||||
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
|
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
|
||||||
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
|
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
|
||||||
|
if not os.path.exists(self.sqlite_db_path):
|
||||||
|
raise Exception(f"SQLite database not found at {self.sqlite_db_path}")
|
||||||
print(f"Using SQLite database at {self.sqlite_db_path}")
|
print(f"Using SQLite database at {self.sqlite_db_path}")
|
||||||
self.job_id = os.environ.get("AITK_JOB_ID", None)
|
self.job_id = os.environ.get("AITK_JOB_ID", None)
|
||||||
|
self.job_id = self.job_id.strip() if self.job_id is not None else None
|
||||||
|
print(f"Job ID: \"{self.job_id}\"")
|
||||||
if self.job_id is None:
|
if self.job_id is None:
|
||||||
raise Exception("AITK_JOB_ID not set")
|
raise Exception("AITK_JOB_ID not set")
|
||||||
self.is_stopping = False
|
self.is_stopping = False
|
||||||
|
|||||||
8
run.py
8
run.py
@@ -95,6 +95,14 @@ def main():
|
|||||||
if not args.recover:
|
if not args.recover:
|
||||||
print_end_message(jobs_completed, jobs_failed)
|
print_end_message(jobs_completed, jobs_failed)
|
||||||
raise e
|
raise e
|
||||||
|
except KeyboardInterrupt as e:
|
||||||
|
try:
|
||||||
|
job.process[0].on_error(e)
|
||||||
|
except Exception as e2:
|
||||||
|
print_acc(f"Error running on_error: {e2}")
|
||||||
|
if not args.recover:
|
||||||
|
print_end_message(jobs_completed, jobs_failed)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import path from 'path';
|
|||||||
import fs from 'fs';
|
import fs from 'fs';
|
||||||
import os from 'os';
|
import os from 'os';
|
||||||
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
||||||
|
const isWindows = process.platform === 'win32';
|
||||||
|
|
||||||
const prisma = new PrismaClient();
|
const prisma = new PrismaClient();
|
||||||
|
|
||||||
@@ -52,13 +53,13 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
|||||||
let pythonPath = 'python';
|
let pythonPath = 'python';
|
||||||
// use .venv or venv if it exists
|
// use .venv or venv if it exists
|
||||||
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
||||||
if (os.platform() === 'win32') {
|
if (isWindows) {
|
||||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
|
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
|
||||||
} else {
|
} else {
|
||||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
||||||
}
|
}
|
||||||
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
||||||
if (os.platform() === 'win32') {
|
if (isWindows) {
|
||||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
|
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
|
||||||
} else {
|
} else {
|
||||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
||||||
@@ -80,23 +81,55 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
|||||||
additionalEnv.HF_TOKEN = hfToken;
|
additionalEnv.HF_TOKEN = hfToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
// console.log(
|
let cmd = `${pythonPath} ${runFilePath} ${configPath}`;
|
||||||
// 'Spawning command:',
|
for (const key in additionalEnv) {
|
||||||
// `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
|
if (os.platform() === 'win32') {
|
||||||
// );
|
cmd = `set ${key}=${additionalEnv[key]} && ${cmd}`;
|
||||||
|
} else {
|
||||||
|
cmd = `${key}=${additionalEnv[key]} ${cmd}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('Spawning command:', cmd);
|
||||||
|
|
||||||
// start job
|
// start job
|
||||||
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
if (isWindows) {
|
||||||
detached: true,
|
// For Windows, use 'cmd.exe' to open a new command window
|
||||||
stdio: 'ignore',
|
const subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, runFilePath, configPath], {
|
||||||
env: {
|
env: {
|
||||||
...process.env,
|
...process.env,
|
||||||
...additionalEnv,
|
...additionalEnv,
|
||||||
},
|
},
|
||||||
cwd: TOOLKIT_ROOT,
|
cwd: TOOLKIT_ROOT,
|
||||||
});
|
windowsHide: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
subprocess.unref();
|
||||||
|
} else {
|
||||||
|
// For non-Windows platforms, use your original approach
|
||||||
|
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||||
|
detached: true,
|
||||||
|
stdio: 'ignore',
|
||||||
|
env: {
|
||||||
|
...process.env,
|
||||||
|
...additionalEnv,
|
||||||
|
},
|
||||||
|
cwd: TOOLKIT_ROOT,
|
||||||
|
});
|
||||||
|
|
||||||
|
subprocess.unref();
|
||||||
|
}
|
||||||
|
// const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||||
|
// detached: true,
|
||||||
|
// stdio: 'ignore',
|
||||||
|
// env: {
|
||||||
|
// ...process.env,
|
||||||
|
// ...additionalEnv,
|
||||||
|
// },
|
||||||
|
// cwd: TOOLKIT_ROOT,
|
||||||
|
// });
|
||||||
|
|
||||||
subprocess.unref();
|
// subprocess.unref();
|
||||||
|
|
||||||
return NextResponse.json(job);
|
return NextResponse.json(job);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ export const options = {
|
|||||||
'config.process[0].model.quantize_te': [true, false],
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
'config.process[0].model.is_flux': [true, false],
|
'config.process[0].model.is_flux': [true, false],
|
||||||
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -27,6 +29,8 @@ export const options = {
|
|||||||
'config.process[0].model.quantize': [true, false],
|
'config.process[0].model.quantize': [true, false],
|
||||||
'config.process[0].model.quantize_te': [true, false],
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
'config.process[0].model.is_flux': [true, false],
|
'config.process[0].model.is_flux': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -36,12 +40,16 @@ export const options = {
|
|||||||
'config.process[0].model.quantize': [false, false],
|
'config.process[0].model.quantize': [false, false],
|
||||||
'config.process[0].model.quantize_te': [true, false],
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
'config.process[0].model.is_lumina2': [true, false],
|
'config.process[0].model.is_lumina2': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name_or_path: 'ostris/objective-reality',
|
name_or_path: 'ostris/objective-reality',
|
||||||
dev_only: true,
|
dev_only: true,
|
||||||
defaults: {
|
defaults: {
|
||||||
|
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -362,7 +362,10 @@ export default function TrainingForm() {
|
|||||||
className="pt-2"
|
className="pt-2"
|
||||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
value={jobConfig.config.process[0].train.noise_scheduler}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
||||||
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
|
options={[
|
||||||
|
{ value: 'flowmatch', label: 'FlowMatch' },
|
||||||
|
{ value: 'ddpm', label: 'DDPM' },
|
||||||
|
]}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
@@ -516,7 +519,10 @@ export default function TrainingForm() {
|
|||||||
className="pt-2"
|
className="pt-2"
|
||||||
value={jobConfig.config.process[0].sample.sampler}
|
value={jobConfig.config.process[0].sample.sampler}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
|
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
|
||||||
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
|
options={[
|
||||||
|
{ value: 'flowmatch', label: 'FlowMatch' },
|
||||||
|
{ value: 'ddpm', label: 'DDPM' },
|
||||||
|
]}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
|
|||||||
Reference in New Issue
Block a user