Capture speed from the timer for the ui

This commit is contained in:
Jaret Burkett
2025-02-23 14:38:46 -07:00
parent 3e49337a58
commit 56d8d6bd81
6 changed files with 46 additions and 17 deletions

View File

@@ -6,6 +6,7 @@ import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional
AITK_Status = Literal["running", "stopped", "error", "completed"]
@@ -81,10 +82,16 @@ class UITrainer(SDTrainer):
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE")
try:
# Convert the value to string if it's not already
if isinstance(value, str):
value_to_insert = value
else:
value_to_insert = str(value)
# Use parameterized query for both the column name and value
update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
cursor.execute(
f"UPDATE Job SET {key} = ? WHERE id = ?",
(value, self.job_id)
)
update_query, (value_to_insert, self.job_id))
finally:
cursor.execute("COMMIT")
@@ -95,13 +102,10 @@ class UITrainer(SDTrainer):
if self.accelerator.is_main_process:
self._run_async_operation(self._update_key("step", self.step_num))
def update_db_key(self, key, value):
"""Non-blocking update a key in the database."""
if self.accelerator.is_main_process:
self._run_async_operation(self._update_key(key, value))
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
if not self.accelerator.is_main_process:
@@ -151,6 +155,19 @@ class UITrainer(SDTrainer):
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def handle_timing_print_hook(self, timing_dict):
if "train_loop" not in timing_dict:
print("train_loop not found in timing_dict", timing_dict)
return
seconds_per_iter = timing_dict["train_loop"]
# determine iter/sec or sec/iter
if seconds_per_iter < 1:
iters_per_sec = 1 / seconds_per_iter
self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
else:
self.update_db_key(
"speed_string", f"{seconds_per_iter:.2f} sec/iter")
def done_hook(self):
super(UITrainer, self).done_hook()
self.update_status("completed", "Training completed")
@@ -178,6 +195,7 @@ class UITrainer(SDTrainer):
self.maybe_stop()
self.update_step()
self.update_status("running", "Training")
self.timer.add_after_print_hook(self.handle_timing_print_hook)
def status_update_hook_func(self, string):
self.update_status("running", string)
@@ -206,4 +224,4 @@ class UITrainer(SDTrainer):
self.update_status("running", "Saving model")
super().save(step)
self.maybe_stop()
self.update_status("running", "Training")
self.update_status("running", "Training")

View File

@@ -700,7 +700,7 @@ class StableDiffusion:
flush()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
self.print_and_status_update("Loading vae")
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
@@ -709,7 +709,7 @@ class StableDiffusion:
text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
else:
self.print_and_status_update("Loading t5")
self.print_and_status_update("Loading T5")
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
torch_dtype=dtype)
@@ -726,7 +726,7 @@ class StableDiffusion:
freeze(text_encoder_2)
flush()
self.print_and_status_update("Loading clip")
self.print_and_status_update("Loading CLIP")
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)

View File

@@ -9,6 +9,7 @@ class Timer:
self.timers = OrderedDict()
self.active_timers = {}
self.current_timer = None # Used for the context manager functionality
self._after_print_hooks = []
def start(self, timer_name):
if timer_name not in self.timers:
@@ -34,12 +35,20 @@ class Timer:
if len(self.timers[timer_name]) > self.max_buffer:
self.timers[timer_name].popleft()
def add_after_print_hook(self, hook):
self._after_print_hooks.append(hook)
def print(self):
print(f"\nTimer '{self.name}':")
timing_dict = {}
# sort by longest at top
for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
avg_time = sum(timings) / len(timings)
print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
timing_dict[timer_name] = avg_time
for hook in self._after_print_hooks:
hook(timing_dict)
print('')

View File

@@ -24,6 +24,7 @@ export const defaultJobConfig: JobConfig = {
sqlite_db_path: './aitk_db.db',
device: 'cuda:0',
trigger_word: null,
performance_log_every: 10,
network: {
type: 'lora',
linear: 16,

View File

@@ -2,8 +2,8 @@ import { Job } from '@prisma/client';
import useGPUInfo from '@/hooks/useGPUInfo';
import GPUWidget from '@/components/GPUWidget';
import FilesWidget from '@/components/FilesWidget';
import { getJobConfig, getTotalSteps } from '@/utils/jobs';
import { Cpu, HardDrive, Info } from 'lucide-react';
import { getTotalSteps } from '@/utils/jobs';
import { Cpu, HardDrive, Info, Gauge } from 'lucide-react';
import { useMemo } from 'react';
interface JobOverviewProps {
@@ -45,7 +45,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
{/* Job Information Panel */}
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<h2 className="font-semibold text-gray-100">Job Details</h2>
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
</div>
@@ -64,7 +64,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
</div>
{/* Job Info Grid */}
<div className="grid gap-4">
<div className="grid gap-4 grid-cols-1 md:grid-cols-3">
<div className="flex items-center space-x-4">
<HardDrive className="w-5 h-5 text-blue-400" />
<div>
@@ -82,10 +82,10 @@ export default function JobOverview({ job }: JobOverviewProps) {
</div>
<div className="flex items-center space-x-4">
<Info className="w-5 h-5 text-amber-400" />
<Gauge className="w-5 h-5 text-green-400" />
<div>
<p className="text-xs text-gray-400">Additional Information</p>
<p className="text-sm font-medium text-gray-200">{job.info}</p>
<p className="text-xs text-gray-400">Speed</p>
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
</div>
</div>
</div>

View File

@@ -132,6 +132,7 @@ export interface ProcessConfig {
type: 'ui_trainer';
sqlite_db_path?: string;
training_folder: string;
performance_log_every: number;
trigger_word: string | null;
device: string;
network?: NetworkConfig;