mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add loss graph to the ui
This commit is contained in:
@@ -127,7 +127,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.has_first_sample_requested = False
|
||||
self.first_sample_config = self.sample_config
|
||||
self.logging_config = LoggingConfig(**self.get_conf('logging', {}))
|
||||
self.logger = create_logger(self.logging_config, config)
|
||||
self.logger = create_logger(self.logging_config, config, self.save_root)
|
||||
self.optimizer: torch.optim.Optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.data_loader: Union[DataLoader, None] = None
|
||||
@@ -2308,7 +2308,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# commit log
|
||||
if self.accelerator.is_main_process:
|
||||
self.logger.commit(step=self.step_num)
|
||||
with self.timer('commit_logger'):
|
||||
self.logger.commit(step=self.step_num)
|
||||
|
||||
# sets progress bar to match out step
|
||||
if self.progress_bar is not None:
|
||||
|
||||
@@ -35,6 +35,7 @@ class LoggingConfig:
|
||||
self.log_every: int = kwargs.get('log_every', 100)
|
||||
self.verbose: bool = kwargs.get('verbose', False)
|
||||
self.use_wandb: bool = kwargs.get('use_wandb', False)
|
||||
self.use_ui_logger: bool = kwargs.get('use_ui_logger', False)
|
||||
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
|
||||
self.run_name: str = kwargs.get('run_name', None)
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ from typing import OrderedDict, Optional
|
||||
from PIL import Image
|
||||
|
||||
from toolkit.config_modules import LoggingConfig
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
|
||||
# Base logger class
|
||||
# This class does nothing, it's just a placeholder
|
||||
@@ -12,11 +17,11 @@ class EmptyLogger:
|
||||
# start logging the training
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
|
||||
# collect the log to send
|
||||
def log(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# send the log
|
||||
def commit(self, step: Optional[int] = None):
|
||||
pass
|
||||
@@ -29,6 +34,7 @@ class EmptyLogger:
|
||||
def finish(self):
|
||||
pass
|
||||
|
||||
|
||||
# Wandb logger class
|
||||
# This class logs the data to wandb
|
||||
class WandbLogger(EmptyLogger):
|
||||
@@ -41,13 +47,15 @@ class WandbLogger(EmptyLogger):
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
|
||||
|
||||
raise ImportError(
|
||||
"Failed to import wandb. Please install wandb by running `pip install wandb`"
|
||||
)
|
||||
|
||||
# send the whole config to wandb
|
||||
run = wandb.init(project=self.project, name=self.run_name, config=self.config)
|
||||
self.run = run
|
||||
self._log = wandb.log # log function
|
||||
self._image = wandb.Image # image object
|
||||
self._log = wandb.log # log function
|
||||
self._image = wandb.Image # image object
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
# when commit is False, wandb increments the step,
|
||||
@@ -74,11 +82,235 @@ class WandbLogger(EmptyLogger):
|
||||
def finish(self):
|
||||
self.run.finish()
|
||||
|
||||
|
||||
class UILogger:
|
||||
def __init__(
|
||||
self,
|
||||
log_file: str,
|
||||
flush_every_n: int = 256,
|
||||
flush_every_secs: float = 0.25,
|
||||
) -> None:
|
||||
self.log_file = log_file
|
||||
self._log_to_commit: Dict[str, Any] = {}
|
||||
|
||||
self._con: Optional[sqlite3.Connection] = None
|
||||
self._started = False
|
||||
|
||||
self._step_counter = 0
|
||||
|
||||
# buffered writes
|
||||
self._pending_steps: List[Tuple[int, float]] = []
|
||||
self._pending_metrics: List[
|
||||
Tuple[int, str, Optional[float], Optional[str]]
|
||||
] = []
|
||||
self._pending_key_minmax: Dict[str, Tuple[int, int]] = {}
|
||||
|
||||
self._flush_every_n = int(flush_every_n)
|
||||
self._flush_every_secs = float(flush_every_secs)
|
||||
self._last_flush = time.time()
|
||||
|
||||
# start logging the training
|
||||
def start(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
parent = os.path.dirname(os.path.abspath(self.log_file))
|
||||
if parent and not os.path.exists(parent):
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
self._con = sqlite3.connect(self.log_file, timeout=30.0, isolation_level=None)
|
||||
self._con.execute("PRAGMA journal_mode=WAL;")
|
||||
self._con.execute("PRAGMA synchronous=NORMAL;")
|
||||
self._con.execute("PRAGMA temp_store=MEMORY;")
|
||||
self._con.execute("PRAGMA foreign_keys=ON;")
|
||||
self._con.execute("PRAGMA busy_timeout=30000;")
|
||||
|
||||
self._init_schema(self._con)
|
||||
|
||||
self._started = True
|
||||
self._last_flush = time.time()
|
||||
|
||||
# collect the log to send
|
||||
def log(self, log_dict):
|
||||
# log_dict is like {'learning_rate': learning_rate}
|
||||
if not isinstance(log_dict, dict):
|
||||
raise TypeError("log_dict must be a dict")
|
||||
self._log_to_commit.update(log_dict)
|
||||
|
||||
# send the log
|
||||
def commit(self, step: Optional[int] = None):
|
||||
if not self._started:
|
||||
self.start()
|
||||
|
||||
if not self._log_to_commit:
|
||||
return
|
||||
|
||||
if step is None:
|
||||
step = self._step_counter
|
||||
self._step_counter += 1
|
||||
else:
|
||||
step = int(step)
|
||||
if step >= self._step_counter:
|
||||
self._step_counter = step + 1
|
||||
|
||||
wall_time = time.time()
|
||||
|
||||
# buffer step row (upsert later)
|
||||
self._pending_steps.append((step, wall_time))
|
||||
|
||||
# buffer metrics rows + key min/max updates
|
||||
for k, v in self._log_to_commit.items():
|
||||
k = k if isinstance(k, str) else str(k)
|
||||
vr, vt = self._coerce_value(v)
|
||||
|
||||
self._pending_metrics.append((step, k, vr, vt))
|
||||
|
||||
if k in self._pending_key_minmax:
|
||||
lo, hi = self._pending_key_minmax[k]
|
||||
if step < lo:
|
||||
lo = step
|
||||
if step > hi:
|
||||
hi = step
|
||||
self._pending_key_minmax[k] = (lo, hi)
|
||||
else:
|
||||
self._pending_key_minmax[k] = (step, step)
|
||||
|
||||
self._log_to_commit = {}
|
||||
|
||||
# flush conditions
|
||||
now = time.time()
|
||||
if (
|
||||
len(self._pending_metrics) >= self._flush_every_n
|
||||
or (now - self._last_flush) >= self._flush_every_secs
|
||||
):
|
||||
self._flush()
|
||||
|
||||
# log image
|
||||
def log_image(self, *args, **kwargs):
|
||||
# this doesnt log images for now
|
||||
pass
|
||||
|
||||
# finish logging
|
||||
def finish(self):
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
self._flush()
|
||||
|
||||
assert self._con is not None
|
||||
self._con.close()
|
||||
self._con = None
|
||||
self._started = False
|
||||
|
||||
# -------------------------
|
||||
# internal
|
||||
# -------------------------
|
||||
|
||||
def _init_schema(self, con: sqlite3.Connection) -> None:
|
||||
con.execute("BEGIN;")
|
||||
|
||||
con.execute("""
|
||||
CREATE TABLE IF NOT EXISTS steps (
|
||||
step INTEGER PRIMARY KEY,
|
||||
wall_time REAL NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
con.execute("""
|
||||
CREATE TABLE IF NOT EXISTS metric_keys (
|
||||
key TEXT PRIMARY KEY,
|
||||
first_seen_step INTEGER,
|
||||
last_seen_step INTEGER
|
||||
);
|
||||
""")
|
||||
|
||||
con.execute("""
|
||||
CREATE TABLE IF NOT EXISTS metrics (
|
||||
step INTEGER NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value_real REAL,
|
||||
value_text TEXT,
|
||||
PRIMARY KEY (step, key),
|
||||
FOREIGN KEY (step) REFERENCES steps(step) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
con.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_metrics_key_step ON metrics (key, step);"
|
||||
)
|
||||
|
||||
con.execute("COMMIT;")
|
||||
|
||||
def _coerce_value(self, v: Any) -> Tuple[Optional[float], Optional[str]]:
|
||||
if v is None:
|
||||
return None, None
|
||||
if isinstance(v, bool):
|
||||
return float(int(v)), None
|
||||
if isinstance(v, (int, float)):
|
||||
return float(v), None
|
||||
try:
|
||||
return float(v), None # type: ignore[arg-type]
|
||||
except Exception:
|
||||
return None, str(v)
|
||||
|
||||
def _flush(self) -> None:
|
||||
if not self._pending_steps and not self._pending_metrics:
|
||||
return
|
||||
|
||||
assert self._con is not None
|
||||
con = self._con
|
||||
|
||||
con.execute("BEGIN;")
|
||||
|
||||
# steps upsert
|
||||
if self._pending_steps:
|
||||
con.executemany(
|
||||
"INSERT INTO steps(step, wall_time) VALUES(?, ?) "
|
||||
"ON CONFLICT(step) DO UPDATE SET wall_time=excluded.wall_time;",
|
||||
self._pending_steps,
|
||||
)
|
||||
|
||||
# keys table upsert (maintains list of keys + seen range)
|
||||
if self._pending_key_minmax:
|
||||
con.executemany(
|
||||
"INSERT INTO metric_keys(key, first_seen_step, last_seen_step) VALUES(?, ?, ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET "
|
||||
"first_seen_step=MIN(metric_keys.first_seen_step, excluded.first_seen_step), "
|
||||
"last_seen_step=MAX(metric_keys.last_seen_step, excluded.last_seen_step);",
|
||||
[(k, lo, hi) for k, (lo, hi) in self._pending_key_minmax.items()],
|
||||
)
|
||||
|
||||
# metrics upsert
|
||||
if self._pending_metrics:
|
||||
con.executemany(
|
||||
"INSERT INTO metrics(step, key, value_real, value_text) VALUES(?, ?, ?, ?) "
|
||||
"ON CONFLICT(step, key) DO UPDATE SET "
|
||||
"value_real=excluded.value_real, value_text=excluded.value_text;",
|
||||
self._pending_metrics,
|
||||
)
|
||||
|
||||
con.execute("COMMIT;")
|
||||
|
||||
self._pending_steps.clear()
|
||||
self._pending_metrics.clear()
|
||||
self._pending_key_minmax.clear()
|
||||
self._last_flush = time.time()
|
||||
|
||||
|
||||
# create logger based on the logging config
|
||||
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
|
||||
def create_logger(
|
||||
logging_config: LoggingConfig,
|
||||
all_config: OrderedDict,
|
||||
save_root: Optional[str] = None,
|
||||
):
|
||||
if logging_config.use_wandb:
|
||||
project_name = logging_config.project_name
|
||||
run_name = logging_config.run_name
|
||||
return WandbLogger(project=project_name, run_name=run_name, config=all_config)
|
||||
elif logging_config.use_ui_logger:
|
||||
if save_root is None:
|
||||
raise ValueError("save_root must be provided when using UILogger")
|
||||
log_file = os.path.join(save_root, "loss_log.db")
|
||||
return UILogger(log_file=log_file)
|
||||
else:
|
||||
return EmptyLogger()
|
||||
|
||||
884
ui/package-lock.json
generated
884
ui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@
|
||||
"axios": "^1.7.9",
|
||||
"classnames": "^2.5.1",
|
||||
"lucide-react": "^0.475.0",
|
||||
"next": "15.1.11",
|
||||
"next": "^15.5.9",
|
||||
"node-cache": "^5.1.2",
|
||||
"prisma": "^6.3.1",
|
||||
"react": "^19.0.0",
|
||||
@@ -28,6 +28,7 @@
|
||||
"react-global-hooks": "^1.3.5",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-select": "^5.10.1",
|
||||
"recharts": "^3.6.0",
|
||||
"sqlite3": "^5.1.7",
|
||||
"systeminformation": "^5.27.11",
|
||||
"uuid": "^11.1.0",
|
||||
|
||||
98
ui/src/app/api/jobs/[jobID]/loss/route.ts
Normal file
98
ui/src/app/api/jobs/[jobID]/loss/route.ts
Normal file
@@ -0,0 +1,98 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
|
||||
import sqlite3 from 'sqlite3';
|
||||
|
||||
export const runtime = 'nodejs';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
function openDb(filename: string) {
|
||||
const db = new sqlite3.Database(filename);
|
||||
db.configure('busyTimeout', 30_000);
|
||||
return db;
|
||||
}
|
||||
|
||||
function all<T = any>(db: sqlite3.Database, sql: string, params: any[] = []) {
|
||||
return new Promise<T[]>((resolve, reject) => {
|
||||
db.all(sql, params, (err, rows) => {
|
||||
if (err) reject(err);
|
||||
else resolve(rows as T[]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function closeDb(db: sqlite3.Database) {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
db.close((err) => (err ? reject(err) : resolve()));
|
||||
});
|
||||
}
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
// this must be awaited to avoid TS error
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({ where: { id: jobID } });
|
||||
if (!job) return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
|
||||
const trainingFolder = await getTrainingFolder();
|
||||
const jobFolder = path.join(trainingFolder, job.name);
|
||||
const logPath = path.join(jobFolder, 'loss_log.db');
|
||||
|
||||
if (!fs.existsSync(logPath)) {
|
||||
return NextResponse.json({ keys: [], key: 'loss', points: [] });
|
||||
}
|
||||
|
||||
const url = new URL(request.url);
|
||||
const key = url.searchParams.get('key') ?? 'loss';
|
||||
const limit = Math.min(Number(url.searchParams.get('limit') ?? 2000), 20000);
|
||||
const sinceStepParam = url.searchParams.get('since_step');
|
||||
const sinceStep = sinceStepParam != null ? Number(sinceStepParam) : null;
|
||||
const stride = Math.max(1, Number(url.searchParams.get('stride') ?? 1));
|
||||
|
||||
const db = openDb(logPath);
|
||||
|
||||
try {
|
||||
const keysRows = await all<{ key: string }>(db, `SELECT key FROM metric_keys ORDER BY key ASC`);
|
||||
const keys = keysRows.map((r) => r.key);
|
||||
|
||||
const points = await all<{
|
||||
step: number;
|
||||
wall_time: number;
|
||||
value: number | null;
|
||||
value_text: string | null;
|
||||
}>(
|
||||
db,
|
||||
`
|
||||
SELECT
|
||||
m.step AS step,
|
||||
s.wall_time AS wall_time,
|
||||
m.value_real AS value,
|
||||
m.value_text AS value_text
|
||||
FROM metrics m
|
||||
JOIN steps s ON s.step = m.step
|
||||
WHERE m.key = ?
|
||||
AND (? IS NULL OR m.step > ?)
|
||||
AND (m.step % ?) = 0
|
||||
ORDER BY m.step ASC
|
||||
LIMIT ?
|
||||
`,
|
||||
[key, sinceStep, sinceStep, stride, limit]
|
||||
);
|
||||
|
||||
return NextResponse.json({
|
||||
key,
|
||||
keys,
|
||||
points: points.map((p) => ({
|
||||
step: p.step,
|
||||
wall_time: p.wall_time,
|
||||
value: p.value ?? (p.value_text ? Number(p.value_text) : null),
|
||||
})),
|
||||
});
|
||||
} finally {
|
||||
await closeDb(db);
|
||||
}
|
||||
}
|
||||
@@ -10,9 +10,10 @@ import JobOverview from '@/components/JobOverview';
|
||||
import { redirect } from 'next/navigation';
|
||||
import JobActionBar from '@/components/JobActionBar';
|
||||
import JobConfigViewer from '@/components/JobConfigViewer';
|
||||
import JobLossGraph from '@/components/JobLossGraph';
|
||||
import { Job } from '@prisma/client';
|
||||
|
||||
type PageKey = 'overview' | 'samples' | 'config';
|
||||
type PageKey = 'overview' | 'samples' | 'config' | 'loss_log';
|
||||
|
||||
interface Page {
|
||||
name: string;
|
||||
@@ -36,6 +37,12 @@ const pages: Page[] = [
|
||||
menuItem: SampleImagesMenu,
|
||||
mainCss: 'pt-24',
|
||||
},
|
||||
{
|
||||
name: 'Loss Graph',
|
||||
value: 'loss_log',
|
||||
component: JobLossGraph,
|
||||
mainCss: 'pt-24',
|
||||
},
|
||||
{
|
||||
name: 'Config File',
|
||||
value: 'config',
|
||||
|
||||
@@ -92,6 +92,10 @@ export const defaultJobConfig: JobConfig = {
|
||||
switch_boundary_every: 1,
|
||||
loss_type: 'mse',
|
||||
},
|
||||
logging: {
|
||||
log_every: 1,
|
||||
use_ui_logger: true,
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
quantize: true,
|
||||
@@ -187,5 +191,13 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
|
||||
false) as boolean;
|
||||
delete jobConfig.config.process[0].model.auto_memory;
|
||||
}
|
||||
|
||||
if (!('logging' in jobConfig.config.process[0])) {
|
||||
//@ts-ignore
|
||||
jobConfig.config.process[0].logging = {
|
||||
log_every: 1,
|
||||
use_ui_logger: true,
|
||||
};
|
||||
}
|
||||
return jobConfig;
|
||||
};
|
||||
|
||||
432
ui/src/components/JobLossGraph.tsx
Normal file
432
ui/src/components/JobLossGraph.tsx
Normal file
@@ -0,0 +1,432 @@
|
||||
'use client';
|
||||
|
||||
import { Job } from '@prisma/client';
|
||||
import useJobLossLog, { LossPoint } from '@/hooks/useJobLossLog';
|
||||
import { useMemo, useState, useEffect } from 'react';
|
||||
import { ResponsiveContainer, LineChart, Line, XAxis, YAxis, Tooltip, CartesianGrid, Legend } from 'recharts';
|
||||
|
||||
interface Props {
|
||||
job: Job;
|
||||
}
|
||||
|
||||
function formatNum(v: number) {
|
||||
if (!Number.isFinite(v)) return '';
|
||||
if (Math.abs(v) >= 1000) return v.toFixed(0);
|
||||
if (Math.abs(v) >= 10) return v.toFixed(3);
|
||||
if (Math.abs(v) >= 1) return v.toFixed(4);
|
||||
return v.toPrecision(4);
|
||||
}
|
||||
|
||||
function clamp01(x: number) {
|
||||
return Math.max(0, Math.min(1, x));
|
||||
}
|
||||
|
||||
// EMA smoothing that works on a per-series list.
|
||||
// alpha=1 -> no smoothing, alpha closer to 0 -> more smoothing.
|
||||
function emaSmoothPoints(points: { step: number; value: number }[], alpha: number) {
|
||||
if (points.length === 0) return [];
|
||||
const a = clamp01(alpha);
|
||||
const out: { step: number; value: number }[] = new Array(points.length);
|
||||
|
||||
let prev = points[0].value;
|
||||
out[0] = { step: points[0].step, value: prev };
|
||||
|
||||
for (let i = 1; i < points.length; i++) {
|
||||
const x = points[i].value;
|
||||
prev = a * x + (1 - a) * prev;
|
||||
out[i] = { step: points[i].step, value: prev };
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
function hashToIndex(str: string, mod: number) {
|
||||
let h = 2166136261;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
h ^= str.charCodeAt(i);
|
||||
h = Math.imul(h, 16777619);
|
||||
}
|
||||
return Math.abs(h) % mod;
|
||||
}
|
||||
|
||||
const PALETTE = [
|
||||
'rgba(96,165,250,1)', // blue-400
|
||||
'rgba(52,211,153,1)', // emerald-400
|
||||
'rgba(167,139,250,1)', // purple-400
|
||||
'rgba(251,191,36,1)', // amber-400
|
||||
'rgba(244,114,182,1)', // pink-400
|
||||
'rgba(248,113,113,1)', // red-400
|
||||
'rgba(34,211,238,1)', // cyan-400
|
||||
'rgba(129,140,248,1)', // indigo-400
|
||||
];
|
||||
|
||||
function strokeForKey(key: string) {
|
||||
return PALETTE[hashToIndex(key, PALETTE.length)];
|
||||
}
|
||||
|
||||
export default function JobLossGraph({ job }: Props) {
|
||||
const { series, lossKeys, status, refreshLoss } = useJobLossLog(job.id, 2000);
|
||||
|
||||
// Controls
|
||||
const [useLogScale, setUseLogScale] = useState(false);
|
||||
const [showRaw, setShowRaw] = useState(false);
|
||||
const [showSmoothed, setShowSmoothed] = useState(true);
|
||||
|
||||
// 0..100 slider. 100 = no smoothing, 0 = heavy smoothing.
|
||||
const [smoothing, setSmoothing] = useState(90);
|
||||
|
||||
// UI-only downsample for rendering speed
|
||||
const [plotStride, setPlotStride] = useState(1);
|
||||
|
||||
// show only last N points in the chart (0 = all)
|
||||
const [windowSize, setWindowSize] = useState<number>(4000);
|
||||
|
||||
// quick y clipping for readability
|
||||
const [clipOutliers, setClipOutliers] = useState(false);
|
||||
|
||||
// which loss series are enabled (default: all enabled)
|
||||
const [enabled, setEnabled] = useState<Record<string, boolean>>({});
|
||||
|
||||
// keep enabled map in sync with discovered keys (enable new ones automatically)
|
||||
useEffect(() => {
|
||||
setEnabled(prev => {
|
||||
const next = { ...prev };
|
||||
for (const k of lossKeys) {
|
||||
if (next[k] === undefined) next[k] = true;
|
||||
}
|
||||
// drop removed keys
|
||||
for (const k of Object.keys(next)) {
|
||||
if (!lossKeys.includes(k)) delete next[k];
|
||||
}
|
||||
return next;
|
||||
});
|
||||
}, [lossKeys]);
|
||||
|
||||
const activeKeys = useMemo(() => lossKeys.filter(k => enabled[k] !== false), [lossKeys, enabled]);
|
||||
|
||||
const perSeries = useMemo(() => {
|
||||
// Build per-series processed point arrays (raw + smoothed), then merge by step for charting.
|
||||
const stride = Math.max(1, plotStride | 0);
|
||||
|
||||
// smoothing%: 0 => no smoothing (alpha=1.0), 100 => heavy smoothing (alpha=0.02)
|
||||
const t = clamp01(smoothing / 100);
|
||||
const alpha = 1.0 - t * 0.98; // 1.0 -> 0.02
|
||||
|
||||
const out: Record<string, { raw: { step: number; value: number }[]; smooth: { step: number; value: number }[] }> =
|
||||
{};
|
||||
|
||||
for (const key of activeKeys) {
|
||||
const pts: LossPoint[] = series[key] ?? [];
|
||||
|
||||
let raw = pts
|
||||
.filter(p => p.value !== null && Number.isFinite(p.value as number))
|
||||
.map(p => ({ step: p.step, value: p.value as number }))
|
||||
.filter(p => (useLogScale ? p.value > 0 : true))
|
||||
.filter((_, idx) => idx % stride === 0);
|
||||
|
||||
// windowing (applies after stride)
|
||||
if (windowSize > 0 && raw.length > windowSize) {
|
||||
raw = raw.slice(raw.length - windowSize);
|
||||
}
|
||||
|
||||
const smooth = emaSmoothPoints(raw, alpha);
|
||||
|
||||
out[key] = { raw, smooth };
|
||||
}
|
||||
|
||||
return out;
|
||||
}, [series, activeKeys, smoothing, plotStride, windowSize, useLogScale]);
|
||||
|
||||
const chartData = useMemo(() => {
|
||||
// Merge series into one array of objects keyed by step.
|
||||
// Fields: `${key}__raw` and `${key}__smooth`
|
||||
const map = new Map<number, any>();
|
||||
|
||||
for (const key of activeKeys) {
|
||||
const s = perSeries[key];
|
||||
if (!s) continue;
|
||||
|
||||
for (const p of s.raw) {
|
||||
const row = map.get(p.step) ?? { step: p.step };
|
||||
row[`${key}__raw`] = p.value;
|
||||
map.set(p.step, row);
|
||||
}
|
||||
for (const p of s.smooth) {
|
||||
const row = map.get(p.step) ?? { step: p.step };
|
||||
row[`${key}__smooth`] = p.value;
|
||||
map.set(p.step, row);
|
||||
}
|
||||
}
|
||||
|
||||
const arr = Array.from(map.values());
|
||||
arr.sort((a, b) => a.step - b.step);
|
||||
return arr;
|
||||
}, [activeKeys, perSeries]);
|
||||
|
||||
const hasData = chartData.length > 1;
|
||||
|
||||
const yDomain = useMemo((): [number | 'auto', number | 'auto'] => {
|
||||
if (!clipOutliers || chartData.length < 10) return ['auto', 'auto'];
|
||||
|
||||
// Collect visible values (prefer smoothed if shown, else raw)
|
||||
const vals: number[] = [];
|
||||
for (const row of chartData) {
|
||||
for (const key of activeKeys) {
|
||||
const k = showSmoothed ? `${key}__smooth` : `${key}__raw`;
|
||||
const v = row[k];
|
||||
if (typeof v === 'number' && Number.isFinite(v)) vals.push(v);
|
||||
}
|
||||
}
|
||||
if (vals.length < 10) return ['auto', 'auto'];
|
||||
|
||||
vals.sort((a, b) => a - b);
|
||||
const lo = vals[Math.floor(vals.length * 0.02)];
|
||||
const hi = vals[Math.ceil(vals.length * 0.98) - 1];
|
||||
|
||||
if (!Number.isFinite(lo) || !Number.isFinite(hi) || lo === hi) return ['auto', 'auto'];
|
||||
return [lo, hi];
|
||||
}, [clipOutliers, chartData, activeKeys, showSmoothed]);
|
||||
|
||||
const latestSummary = useMemo(() => {
|
||||
// Provide a simple “latest” readout for the first active series
|
||||
const firstKey = activeKeys[0];
|
||||
if (!firstKey) return null;
|
||||
|
||||
const s = perSeries[firstKey];
|
||||
if (!s) return null;
|
||||
|
||||
const lastRaw = s.raw.length ? s.raw[s.raw.length - 1] : null;
|
||||
const lastSmooth = s.smooth.length ? s.smooth[s.smooth.length - 1] : null;
|
||||
|
||||
return {
|
||||
key: firstKey,
|
||||
step: lastRaw?.step ?? lastSmooth?.step ?? null,
|
||||
raw: lastRaw?.value ?? null,
|
||||
smooth: lastSmooth?.value ?? null,
|
||||
};
|
||||
}, [activeKeys, perSeries]);
|
||||
|
||||
return (
|
||||
<div className="bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800 flex flex-col">
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="h-2 w-2 rounded-full bg-blue-400" />
|
||||
<h2 className="text-gray-100 text-sm font-medium">Loss graph</h2>
|
||||
<span className="text-xs text-gray-400">
|
||||
{status === 'loading' && 'Loading...'}
|
||||
{status === 'refreshing' && 'Refreshing...'}
|
||||
{status === 'error' && 'Error'}
|
||||
{status === 'success' && hasData && `${chartData.length.toLocaleString()} steps`}
|
||||
{status === 'success' && !hasData && 'No data yet'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
onClick={refreshLoss}
|
||||
className="px-3 py-1 rounded-md text-xs bg-gray-700/60 hover:bg-gray-700 text-gray-200 border border-gray-700"
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Chart */}
|
||||
<div className="px-4 pt-4 pb-4">
|
||||
<div className="bg-gray-950 rounded-lg border border-gray-800 h-96 relative">
|
||||
{!hasData ? (
|
||||
<div className="h-full w-full flex items-center justify-center text-sm text-gray-400">
|
||||
{status === 'error' ? 'Failed to load loss logs.' : 'Waiting for loss points...'}
|
||||
</div>
|
||||
) : (
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<LineChart data={chartData} margin={{ top: 10, right: 16, bottom: 10, left: 8 }}>
|
||||
<CartesianGrid strokeDasharray="3 3" stroke="rgba(255,255,255,0.06)" />
|
||||
<XAxis
|
||||
dataKey="step"
|
||||
tick={{ fill: 'rgba(255,255,255,0.55)', fontSize: 12 }}
|
||||
tickLine={{ stroke: 'rgba(255,255,255,0.15)' }}
|
||||
axisLine={{ stroke: 'rgba(255,255,255,0.15)' }}
|
||||
minTickGap={40}
|
||||
/>
|
||||
<YAxis
|
||||
scale={useLogScale ? 'log' : 'linear'}
|
||||
tick={{ fill: 'rgba(255,255,255,0.55)', fontSize: 12 }}
|
||||
tickLine={{ stroke: 'rgba(255,255,255,0.15)' }}
|
||||
axisLine={{ stroke: 'rgba(255,255,255,0.15)' }}
|
||||
width={72}
|
||||
tickFormatter={formatNum}
|
||||
domain={yDomain}
|
||||
allowDataOverflow={clipOutliers}
|
||||
/>
|
||||
<Tooltip
|
||||
cursor={{ stroke: 'rgba(59,130,246,0.25)', strokeWidth: 1 }}
|
||||
contentStyle={{
|
||||
background: 'rgba(17,24,39,0.96)',
|
||||
border: '1px solid rgba(31,41,55,1)',
|
||||
borderRadius: 10,
|
||||
color: 'rgba(255,255,255,0.9)',
|
||||
fontSize: 12,
|
||||
}}
|
||||
labelStyle={{ color: 'rgba(255,255,255,0.75)' }}
|
||||
labelFormatter={(label: any) => `step ${label}`}
|
||||
formatter={(value: any, name: any) => [formatNum(Number(value)), name]}
|
||||
/>
|
||||
|
||||
<Legend
|
||||
wrapperStyle={{
|
||||
paddingTop: 8,
|
||||
color: 'rgba(255,255,255,0.7)',
|
||||
fontSize: 12,
|
||||
}}
|
||||
/>
|
||||
|
||||
{activeKeys.map(k => {
|
||||
const color = strokeForKey(k);
|
||||
|
||||
return (
|
||||
<g key={k}>
|
||||
{showRaw && (
|
||||
<Line
|
||||
type="monotone"
|
||||
dataKey={`${k}__raw`}
|
||||
name={`${k} (raw)`}
|
||||
stroke={color.replace('1)', '0.40)')}
|
||||
strokeWidth={1.25}
|
||||
dot={false}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
)}
|
||||
{showSmoothed && (
|
||||
<Line
|
||||
type="monotone"
|
||||
dataKey={`${k}__smooth`}
|
||||
name={`${k}`}
|
||||
stroke={color}
|
||||
strokeWidth={2}
|
||||
dot={false}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
)}
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
</LineChart>
|
||||
</ResponsiveContainer>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Controls */}
|
||||
<div className="px-4 pb-2">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-3">
|
||||
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
|
||||
<label className="block text-xs text-gray-400 mb-2">Display</label>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<ToggleButton checked={showSmoothed} onClick={() => setShowSmoothed(v => !v)} label="Smoothed" />
|
||||
<ToggleButton checked={showRaw} onClick={() => setShowRaw(v => !v)} label="Raw" />
|
||||
<ToggleButton checked={useLogScale} onClick={() => setUseLogScale(v => !v)} label="Log Y" />
|
||||
<ToggleButton checked={clipOutliers} onClick={() => setClipOutliers(v => !v)} label="Clip outliers" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
|
||||
<label className="block text-xs text-gray-400 mb-2">Series</label>
|
||||
{lossKeys.length === 0 ? (
|
||||
<div className="text-sm text-gray-400">No loss keys found yet.</div>
|
||||
) : (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{lossKeys.map(k => (
|
||||
<button
|
||||
key={k}
|
||||
type="button"
|
||||
onClick={() => setEnabled(prev => ({ ...prev, [k]: !(prev[k] ?? true) }))}
|
||||
className={[
|
||||
'px-3 py-1 rounded-md text-xs border transition-colors',
|
||||
enabled[k] === false
|
||||
? 'bg-gray-900 text-gray-400 border-gray-800 hover:bg-gray-800/60'
|
||||
: 'bg-gray-900 text-gray-200 border-gray-800 hover:bg-gray-800/60',
|
||||
].join(' ')}
|
||||
aria-pressed={enabled[k] !== false}
|
||||
title={k}
|
||||
>
|
||||
<span className="inline-block h-2 w-2 rounded-full mr-2" style={{ background: strokeForKey(k) }} />
|
||||
{k}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
|
||||
<div className="flex items-center justify-between mb-1">
|
||||
<label className="block text-xs text-gray-400">Smoothing</label>
|
||||
<span className="text-xs text-gray-300">{smoothing}%</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min={0}
|
||||
max={100}
|
||||
value={smoothing}
|
||||
onChange={e => setSmoothing(Number(e.target.value))}
|
||||
className="w-full accent-blue-500"
|
||||
disabled={!showSmoothed}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
|
||||
<div className="flex items-center justify-between mb-1">
|
||||
<label className="block text-xs text-gray-400">Plot stride</label>
|
||||
<span className="text-xs text-gray-300">every {plotStride} pt</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min={1}
|
||||
max={20}
|
||||
value={plotStride}
|
||||
onChange={e => setPlotStride(Number(e.target.value))}
|
||||
className="w-full accent-blue-500"
|
||||
/>
|
||||
<div className="mt-2 text-[11px] text-gray-500">UI downsample for huge runs.</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3 md:col-span-2">
|
||||
<div className="flex items-center justify-between mb-1">
|
||||
<label className="block text-xs text-gray-400">Window (last N points)</label>
|
||||
<span className="text-xs text-gray-300">{windowSize === 0 ? 'all' : windowSize.toLocaleString()}</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min={0}
|
||||
max={20000}
|
||||
step={250}
|
||||
value={windowSize}
|
||||
onChange={e => setWindowSize(Number(e.target.value))}
|
||||
className="w-full accent-blue-500"
|
||||
/>
|
||||
<div className="mt-2 text-[11px] text-gray-500">
|
||||
Set to 0 to show all (not recommended for very long runs).
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ToggleButton({ checked, onClick, label }: { checked: boolean; onClick: () => void; label: string }) {
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClick}
|
||||
className={[
|
||||
'px-3 py-1 rounded-md text-xs border transition-colors',
|
||||
checked
|
||||
? 'bg-blue-500/10 text-blue-300 border-blue-500/30 hover:bg-blue-500/15'
|
||||
: 'bg-gray-900 text-gray-300 border-gray-800 hover:bg-gray-800/60',
|
||||
].join(' ')}
|
||||
aria-pressed={checked}
|
||||
>
|
||||
{label}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
145
ui/src/hooks/useJobLossLog.tsx
Normal file
145
ui/src/hooks/useJobLossLog.tsx
Normal file
@@ -0,0 +1,145 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState, useRef, useCallback, useMemo } from 'react';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
// One logged sample of a scalar training metric.
export interface LossPoint {
  step: number; // training step the value was logged at
  wall_time?: number; // optional timestamp from the logger — presumably unix seconds; confirm against the server endpoint
  value: number | null; // null means no value recorded; consumers filter these out
}

// Map from metric key (e.g. "loss", "val_loss") to its ordered point list.
type SeriesMap = Record<string, LossPoint[]>;
|
||||
|
||||
function isLossKey(key: string) {
|
||||
// treat anything containing "loss" as a loss-series
|
||||
// (covers loss, train_loss, val_loss, loss/xyz, etc.)
|
||||
return /loss/i.test(key);
|
||||
}
|
||||
|
||||
/**
 * React hook that loads (and optionally polls) loss-series log data for a job.
 *
 * @param jobID          id of the job whose loss endpoint is queried
 * @param reloadInterval when set (ms), the hook polls on an interval and
 *                       fetches incrementally via `since_step`; when null it
 *                       loads once per jobID change
 * @returns { series, keys, lossKeys, status, refreshLoss, setSeries }
 *
 * NOTE(review): responses are not cancelled when jobID changes; a slow
 * in-flight request for the previous job could resolve after the reset and
 * merge stale points into the new job's series — confirm/guard if job
 * switching mid-load is a real scenario.
 */
export default function useJobLossLog(jobID: string, reloadInterval: null | number = null) {
  // all series fetched so far, keyed by metric name
  const [series, setSeries] = useState<SeriesMap>({});
  // full key list as reported by the server (loss and non-loss)
  const [keys, setKeys] = useState<string[]>([]);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle');

  // distinguishes first load ('loading', replace series) from polls ('refreshing', append)
  const didInitialLoadRef = useRef(false);
  // simple re-entrancy guard so overlapping polls never run concurrently
  const inFlightRef = useRef(false);

  // track last step per key so polling is incremental per series
  const lastStepByKeyRef = useRef<Record<string, number | null>>({});

  // sorted loss-only view of the key list, for chart legends etc.
  const lossKeys = useMemo(() => {
    const base = (keys ?? []).filter(isLossKey);
    // if keys table is empty early on, fall back to just "loss"
    if (base.length === 0) return ['loss'];
    return base.sort();
  }, [keys]);

  const refreshLoss = useCallback(async () => {
    if (!jobID) return;

    if (inFlightRef.current) return;
    inFlightRef.current = true;

    const loadStatus: 'loading' | 'refreshing' = didInitialLoadRef.current ? 'refreshing' : 'loading';
    setStatus(loadStatus);

    try {
      // Step 1: get key list (we can do this by calling endpoint once; it returns keys)
      // Keep it cheap: limit=1.
      const first = await apiClient
        .get(`/api/jobs/${jobID}/loss`, { params: { key: 'loss', limit: 1 } })
        .then(res => res.data as { keys?: string[] });

      const newKeys = first.keys ?? [];
      setKeys(newKeys);

      // same fallback-to-'loss' rule as lossKeys above, but computed from the
      // fresh response rather than state (state update hasn't landed yet)
      const wantedLossKeys = (newKeys.filter(isLossKey).length ? newKeys.filter(isLossKey) : ['loss']).sort();

      // Step 2: fetch each loss key incrementally (since_step per key if polling)
      const requests = wantedLossKeys.map(k => {
        const params: Record<string, any> = { key: k };

        // only send since_step when polling; a one-shot load always fetches everything
        if (reloadInterval && lastStepByKeyRef.current[k] != null) {
          params.since_step = lastStepByKeyRef.current[k];
        }

        // keep default limit from server (or set explicitly if you want)
        // params.limit = 2000;

        return apiClient
          .get(`/api/jobs/${jobID}/loss`, { params })
          .then(res => res.data as { key: string; points?: LossPoint[] });
      });

      const results = await Promise.all(requests);

      // Functional update so concurrent renders can't clobber merged data.
      setSeries(prev => {
        const next: SeriesMap = { ...prev };

        for (const r of results) {
          const k = r.key;
          // drop null-valued placeholders up front
          const newPoints = (r.points ?? []).filter(p => p.value !== null);

          if (!didInitialLoadRef.current) {
            // initial: replace
            next[k] = newPoints;
          } else if (newPoints.length) {
            // append-only merge: keep strictly newer steps to avoid duplicates
            // when the server returns overlap around since_step
            const existing = next[k] ?? [];
            const prevLast = existing.length ? existing[existing.length - 1].step : null;
            const filtered = prevLast == null ? newPoints : newPoints.filter(p => p.step > prevLast);
            next[k] = filtered.length ? [...existing, ...filtered] : existing;
          } else {
            // no new points: keep existing
            next[k] = next[k] ?? [];
          }

          // update last step per key
          const finalArr = next[k] ?? [];
          lastStepByKeyRef.current[k] = finalArr.length
            ? finalArr[finalArr.length - 1].step
            : (lastStepByKeyRef.current[k] ?? null);
        }

        // remove stale loss keys that no longer exist (rare, but keeps UI clean)
        for (const existingKey of Object.keys(next)) {
          if (isLossKey(existingKey) && !wantedLossKeys.includes(existingKey)) {
            delete next[existingKey];
            delete lastStepByKeyRef.current[existingKey];
          }
        }

        return next;
      });

      setStatus('success');
      didInitialLoadRef.current = true;
    } catch (err) {
      console.error('Error fetching loss logs:', err);
      setStatus('error');
    } finally {
      inFlightRef.current = false;
    }
  }, [jobID, reloadInterval]);

  useEffect(() => {
    // reset when job changes
    // (also fires when reloadInterval changes, since refreshLoss depends on it)
    didInitialLoadRef.current = false;
    lastStepByKeyRef.current = {};
    setSeries({});
    setKeys([]);
    setStatus('idle');

    refreshLoss();

    if (reloadInterval) {
      const interval = setInterval(() => {
        refreshLoss();
      }, reloadInterval);

      // stop polling on unmount / dependency change
      return () => clearInterval(interval);
    }
  }, [jobID, reloadInterval, refreshLoss]);

  return { series, keys, lossKeys, status, refreshLoss, setSeries };
}
|
||||
@@ -197,6 +197,11 @@ export interface SampleConfig {
|
||||
fps: number;
|
||||
}
|
||||
|
||||
// UI-side shape of a job's logging config — presumably mirrors the trainer's
// LoggingConfig (toolkit/config_modules); keep the two in sync.
export interface LoggingConfig {
  log_every: number; // steps between metric log commits
  use_ui_logger: boolean; // when true the trainer writes loss data the UI can graph
}
|
||||
|
||||
export interface SliderConfig {
|
||||
guidance_strength?: number;
|
||||
anchor_strength?: number;
|
||||
@@ -218,6 +223,7 @@ export interface ProcessConfig {
|
||||
save: SaveConfig;
|
||||
datasets: DatasetConfig[];
|
||||
train: TrainConfig;
|
||||
logging: LoggingConfig;
|
||||
model: ModelConfig;
|
||||
sample: SampleConfig;
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.7.9"
|
||||
VERSION = "0.7.10"
|
||||
|
||||
Reference in New Issue
Block a user