From 56d8d6bd816b02c0c226c9d757be40c451f55d29 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 23 Feb 2025 14:38:46 -0700 Subject: [PATCH] Capture speed from the timer for the ui --- extensions_built_in/sd_trainer/UITrainer.py | 32 ++++++++++++++++----- toolkit/stable_diffusion_model.py | 6 ++-- toolkit/timer.py | 9 ++++++ ui/src/app/jobs/new/jobConfig.ts | 1 + ui/src/components/JobOverview.tsx | 14 ++++----- ui/src/types.ts | 1 + 6 files changed, 46 insertions(+), 17 deletions(-) diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index fa81c130..f0fdba68 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -6,6 +6,7 @@ import concurrent.futures from extensions_built_in.sd_trainer.SDTrainer import SDTrainer from typing import Literal, Optional + AITK_Status = Literal["running", "stopped", "error", "completed"] @@ -81,10 +82,16 @@ class UITrainer(SDTrainer): cursor = conn.cursor() cursor.execute("BEGIN IMMEDIATE") try: + # Convert the value to string if it's not already + if isinstance(value, str): + value_to_insert = value + else: + value_to_insert = str(value) + + # Use parameterized query for both the column name and value + update_query = f"UPDATE Job SET {key} = ? WHERE id = ?" cursor.execute( - f"UPDATE Job SET {key} = ? WHERE id = ?", - (value, self.job_id) - ) + update_query, (value_to_insert, self.job_id)) finally: cursor.execute("COMMIT") @@ -95,13 +102,10 @@ class UITrainer(SDTrainer): if self.accelerator.is_main_process: self._run_async_operation(self._update_key("step", self.step_num)) - def update_db_key(self, key, value): """Non-blocking update a key in the database.""" if self.accelerator.is_main_process: self._run_async_operation(self._update_key(key, value)) - - async def _update_status(self, status: AITK_Status, info: Optional[str] = None): if not self.accelerator.is_main_process: @@ -151,6 +155,19 @@ class UITrainer(SDTrainer): asyncio.run(self.wait_for_all_async()) self.thread_pool.shutdown(wait=True) + def handle_timing_print_hook(self, timing_dict): + if "train_loop" not in timing_dict: + print("train_loop not found in timing_dict", timing_dict) + return + seconds_per_iter = timing_dict["train_loop"] + # determine iter/sec or sec/iter + if seconds_per_iter < 1: + iters_per_sec = 1 / seconds_per_iter + self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec") + else: + self.update_db_key( + "speed_string", f"{seconds_per_iter:.2f} sec/iter") + def done_hook(self): super(UITrainer, self).done_hook() self.update_status("completed", "Training completed") @@ -178,6 +195,7 @@ class UITrainer(SDTrainer): self.maybe_stop() self.update_step() self.update_status("running", "Training") + self.timer.add_after_print_hook(self.handle_timing_print_hook) def status_update_hook_func(self, string): self.update_status("running", string) @@ -206,4 +224,4 @@ class UITrainer(SDTrainer): self.update_status("running", "Saving model") super().save(step) self.maybe_stop() - self.update_status("running", "Training") \ No newline at end of file + self.update_status("running", "Training") diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 40da202d..00b38574 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -700,7 +700,7 @@ class StableDiffusion: flush() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") - self.print_and_status_update("Loading vae") + self.print_and_status_update("Loading VAE") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() @@ -709,7 +709,7 @@ class StableDiffusion: text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) else: - self.print_and_status_update("Loading t5") + self.print_and_status_update("Loading T5") tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) @@ -726,7 +726,7 @@ class StableDiffusion: freeze(text_encoder_2) flush() - self.print_and_status_update("Loading clip") + self.print_and_status_update("Loading CLIP") text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) diff --git a/toolkit/timer.py b/toolkit/timer.py index ca4fecba..3592ba5e 100644 --- a/toolkit/timer.py +++ b/toolkit/timer.py @@ -9,6 +9,7 @@ class Timer: self.timers = OrderedDict() self.active_timers = {} self.current_timer = None # Used for the context manager functionality + self._after_print_hooks = [] def start(self, timer_name): if timer_name not in self.timers: @@ -34,12 +35,20 @@ class Timer: if len(self.timers[timer_name]) > self.max_buffer: self.timers[timer_name].popleft() + def add_after_print_hook(self, hook): + self._after_print_hooks.append(hook) + def print(self): print(f"\nTimer '{self.name}':") + timing_dict = {} # sort by longest at top for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True): avg_time = sum(timings) / len(timings) print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}") + timing_dict[timer_name] = avg_time + + for hook in self._after_print_hooks: + hook(timing_dict) print('') diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index dde4dbac..91b65540 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -24,6 +24,7 @@ export const defaultJobConfig: JobConfig = { sqlite_db_path: './aitk_db.db', device: 'cuda:0', trigger_word: null, + performance_log_every: 10, network: { type: 'lora', linear: 16, diff --git a/ui/src/components/JobOverview.tsx b/ui/src/components/JobOverview.tsx index ba1ad5c5..2580374c 100644 --- a/ui/src/components/JobOverview.tsx +++ b/ui/src/components/JobOverview.tsx @@ -2,8 +2,8 @@ import { Job } from '@prisma/client'; import useGPUInfo from '@/hooks/useGPUInfo'; import GPUWidget from '@/components/GPUWidget'; import FilesWidget from '@/components/FilesWidget'; -import { getJobConfig, getTotalSteps } from '@/utils/jobs'; -import { Cpu, HardDrive, Info } from 'lucide-react'; +import { getTotalSteps } from '@/utils/jobs'; +import { Cpu, HardDrive, Info, Gauge } from 'lucide-react'; import { useMemo } from 'react'; interface JobOverviewProps { @@ -45,7 +45,7 @@ export default function JobOverview({ job }: JobOverviewProps) { {/* Job Information Panel */}
-

Job Details

+

{job.info}

{job.status}
@@ -64,7 +64,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
{/* Job Info Grid */} -
+
@@ -82,10 +82,10 @@ export default function JobOverview({ job }: JobOverviewProps) {
- +
-

Additional Information

-

{job.info}

+

Speed

+

{job.speed_string == "" ? "?" : job.speed_string}

diff --git a/ui/src/types.ts b/ui/src/types.ts index 5d334128..42c737cd 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -132,6 +132,7 @@ export interface ProcessConfig { type: 'ui_trainer'; sqlite_db_path?: string; training_folder: string; + performance_log_every: number; trigger_word: string | null; device: string; network?: NetworkConfig;