mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Capture speed from the timer for the UI
This commit is contained in:
@@ -6,6 +6,7 @@ import concurrent.futures
|
||||
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
AITK_Status = Literal["running", "stopped", "error", "completed"]
|
||||
|
||||
|
||||
@@ -81,10 +82,16 @@ class UITrainer(SDTrainer):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
# Convert the value to string if it's not already
|
||||
if isinstance(value, str):
|
||||
value_to_insert = value
|
||||
else:
|
||||
value_to_insert = str(value)
|
||||
|
||||
# Bind the value and job id as query parameters; the column name (key) is interpolated into the SQL text, so it must come from trusted code
|
||||
update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
|
||||
cursor.execute(
|
||||
f"UPDATE Job SET {key} = ? WHERE id = ?",
|
||||
(value, self.job_id)
|
||||
)
|
||||
update_query, (value_to_insert, self.job_id))
|
||||
finally:
|
||||
cursor.execute("COMMIT")
|
||||
|
||||
@@ -95,13 +102,10 @@ class UITrainer(SDTrainer):
|
||||
if self.accelerator.is_main_process:
|
||||
self._run_async_operation(self._update_key("step", self.step_num))
|
||||
|
||||
|
||||
def update_db_key(self, key, value):
    """Non-blocking update of a single key in the database.

    The update is built by ``self._update_key(key, value)`` and handed to
    ``self._run_async_operation``. Nothing is scheduled unless this is the
    main process.

    Args:
        key: Name of the key to update.
        value: New value to store for that key.
    """
    # Only the main process writes, so parallel workers do not repeat the update.
    if not self.accelerator.is_main_process:
        return
    self._run_async_operation(self._update_key(key, value))
|
||||
|
||||
|
||||
|
||||
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||
if not self.accelerator.is_main_process:
|
||||
@@ -151,6 +155,19 @@ class UITrainer(SDTrainer):
|
||||
asyncio.run(self.wait_for_all_async())
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
|
||||
def handle_timing_print_hook(self, timing_dict):
    """Store a human-readable training speed derived from a timing report.

    Reads the ``"train_loop"`` entry of ``timing_dict`` as seconds per
    iteration and writes it to the ``speed_string`` key through
    ``self.update_db_key``: as ``iter/sec`` when an iteration takes less
    than one second, otherwise as ``sec/iter``.

    Args:
        timing_dict: Mapping of timer name to a duration in seconds.
    """
    if "train_loop" not in timing_dict:
        print("train_loop not found in timing_dict", timing_dict)
        return
    seconds_per_iter = timing_dict["train_loop"]
    if seconds_per_iter <= 0:
        # A zero (or negative) duration has no meaningful speed, and the
        # reciprocal below would raise ZeroDivisionError.
        return
    # determine iter/sec or sec/iter
    if seconds_per_iter < 1:
        iters_per_sec = 1 / seconds_per_iter
        speed_string = f"{iters_per_sec:.2f} iter/sec"
    else:
        speed_string = f"{seconds_per_iter:.2f} sec/iter"
    self.update_db_key("speed_string", speed_string)
|
||||
|
||||
def done_hook(self):
|
||||
super(UITrainer, self).done_hook()
|
||||
self.update_status("completed", "Training completed")
|
||||
@@ -178,6 +195,7 @@ class UITrainer(SDTrainer):
|
||||
self.maybe_stop()
|
||||
self.update_step()
|
||||
self.update_status("running", "Training")
|
||||
self.timer.add_after_print_hook(self.handle_timing_print_hook)
|
||||
|
||||
def status_update_hook_func(self, string):
|
||||
self.update_status("running", string)
|
||||
@@ -206,4 +224,4 @@ class UITrainer(SDTrainer):
|
||||
self.update_status("running", "Saving model")
|
||||
super().save(step)
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Training")
|
||||
self.update_status("running", "Training")
|
||||
|
||||
@@ -700,7 +700,7 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
self.print_and_status_update("Loading vae")
|
||||
self.print_and_status_update("Loading VAE")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -709,7 +709,7 @@ class StableDiffusion:
|
||||
text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
|
||||
else:
|
||||
self.print_and_status_update("Loading t5")
|
||||
self.print_and_status_update("Loading T5")
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
|
||||
torch_dtype=dtype)
|
||||
@@ -726,7 +726,7 @@ class StableDiffusion:
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
self.print_and_status_update("Loading clip")
|
||||
self.print_and_status_update("Loading CLIP")
|
||||
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
@@ -9,6 +9,7 @@ class Timer:
|
||||
self.timers = OrderedDict()
|
||||
self.active_timers = {}
|
||||
self.current_timer = None # Used for the context manager functionality
|
||||
self._after_print_hooks = []
|
||||
|
||||
def start(self, timer_name):
|
||||
if timer_name not in self.timers:
|
||||
@@ -34,12 +35,20 @@ class Timer:
|
||||
if len(self.timers[timer_name]) > self.max_buffer:
|
||||
self.timers[timer_name].popleft()
|
||||
|
||||
def add_after_print_hook(self, hook):
    """Register a callable to run after the timer report is printed.

    Args:
        hook: Callable appended to ``self._after_print_hooks``.
    """
    self._after_print_hooks.append(hook)
|
||||
|
||||
def print(self):
    """Print the average duration of every timer, then run the after-print hooks.

    Timers are listed with the largest total recorded time first. Each
    registered hook is called once with a dict mapping timer name to its
    average duration in seconds. Timers with no recorded samples are left
    out of both the printout and the dict.
    """
    print(f"\nTimer '{self.name}':")
    timing_dict = {}
    # sort by longest at top
    for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
        if not timings:
            # No samples recorded yet; averaging would divide by zero.
            continue
        avg_time = sum(timings) / len(timings)
        print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
        timing_dict[timer_name] = avg_time

    for hook in self._after_print_hooks:
        hook(timing_dict)

    print('')
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
sqlite_db_path: './aitk_db.db',
|
||||
device: 'cuda:0',
|
||||
trigger_word: null,
|
||||
performance_log_every: 10,
|
||||
network: {
|
||||
type: 'lora',
|
||||
linear: 16,
|
||||
|
||||
@@ -2,8 +2,8 @@ import { Job } from '@prisma/client';
|
||||
import useGPUInfo from '@/hooks/useGPUInfo';
|
||||
import GPUWidget from '@/components/GPUWidget';
|
||||
import FilesWidget from '@/components/FilesWidget';
|
||||
import { getJobConfig, getTotalSteps } from '@/utils/jobs';
|
||||
import { Cpu, HardDrive, Info } from 'lucide-react';
|
||||
import { getTotalSteps } from '@/utils/jobs';
|
||||
import { Cpu, HardDrive, Info, Gauge } from 'lucide-react';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
interface JobOverviewProps {
|
||||
@@ -45,7 +45,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
{/* Job Information Panel */}
|
||||
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<h2 className="font-semibold text-gray-100">Job Details</h2>
|
||||
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
|
||||
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
|
||||
</div>
|
||||
|
||||
@@ -64,7 +64,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
</div>
|
||||
|
||||
{/* Job Info Grid */}
|
||||
<div className="grid gap-4">
|
||||
<div className="grid gap-4 grid-cols-1 md:grid-cols-3">
|
||||
<div className="flex items-center space-x-4">
|
||||
<HardDrive className="w-5 h-5 text-blue-400" />
|
||||
<div>
|
||||
@@ -82,10 +82,10 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-4">
|
||||
<Info className="w-5 h-5 text-amber-400" />
|
||||
<Gauge className="w-5 h-5 text-green-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Additional Information</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.info}</p>
|
||||
<p className="text-xs text-gray-400">Speed</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -132,6 +132,7 @@ export interface ProcessConfig {
|
||||
type: 'ui_trainer';
|
||||
sqlite_db_path?: string;
|
||||
training_folder: string;
|
||||
performance_log_every: number;
|
||||
trigger_word: string | null;
|
||||
device: string;
|
||||
network?: NetworkConfig;
|
||||
|
||||
Reference in New Issue
Block a user