Add loss graph to the UI

Jaret Burkett
2025-12-18 10:08:59 -07:00
parent 3b6c1ade18
commit ba00eea7d9
12 changed files with 1620 additions and 223 deletions

View File

@@ -127,7 +127,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.has_first_sample_requested = False
self.first_sample_config = self.sample_config
self.logging_config = LoggingConfig(**self.get_conf('logging', {}))
-        self.logger = create_logger(self.logging_config, config)
+        self.logger = create_logger(self.logging_config, config, self.save_root)
self.optimizer: torch.optim.Optimizer = None
self.lr_scheduler = None
self.data_loader: Union[DataLoader, None] = None
@@ -2308,7 +2308,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# commit log
if self.accelerator.is_main_process:
-            self.logger.commit(step=self.step_num)
+            with self.timer('commit_logger'):
+                self.logger.commit(step=self.step_num)
# sets the progress bar to match our step
if self.progress_bar is not None:

View File

@@ -35,6 +35,7 @@ class LoggingConfig:
self.log_every: int = kwargs.get('log_every', 100)
self.verbose: bool = kwargs.get('verbose', False)
self.use_wandb: bool = kwargs.get('use_wandb', False)
self.use_ui_logger: bool = kwargs.get('use_ui_logger', False)
self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
self.run_name: str = kwargs.get('run_name', None)

View File

@@ -2,6 +2,11 @@ from typing import OrderedDict, Optional
from PIL import Image
from toolkit.config_modules import LoggingConfig
import os
import sqlite3
import time
from typing import Any, Dict, Tuple, List
# Base logger class
# This class does nothing, it's just a placeholder
@@ -12,11 +17,11 @@ class EmptyLogger:
# start logging the training
def start(self):
pass
# collect the log to send
def log(self, *args, **kwargs):
pass
# send the log
def commit(self, step: Optional[int] = None):
pass
@@ -29,6 +34,7 @@ class EmptyLogger:
def finish(self):
pass
# Wandb logger class
# This class logs the data to wandb
class WandbLogger(EmptyLogger):
@@ -41,13 +47,15 @@ class WandbLogger(EmptyLogger):
try:
import wandb
except ImportError:
-            raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
+            raise ImportError(
+                "Failed to import wandb. Please install wandb by running `pip install wandb`"
+            )
# send the whole config to wandb
run = wandb.init(project=self.project, name=self.run_name, config=self.config)
self.run = run
-        self._log = wandb.log # log function
-        self._image = wandb.Image # image object
+        self._log = wandb.log  # log function
+        self._image = wandb.Image  # image object
def log(self, *args, **kwargs):
# when commit is False, wandb increments the step,
@@ -74,11 +82,235 @@ class WandbLogger(EmptyLogger):
def finish(self):
self.run.finish()
class UILogger:
def __init__(
self,
log_file: str,
flush_every_n: int = 256,
flush_every_secs: float = 0.25,
) -> None:
self.log_file = log_file
self._log_to_commit: Dict[str, Any] = {}
self._con: Optional[sqlite3.Connection] = None
self._started = False
self._step_counter = 0
# buffered writes
self._pending_steps: List[Tuple[int, float]] = []
self._pending_metrics: List[
Tuple[int, str, Optional[float], Optional[str]]
] = []
self._pending_key_minmax: Dict[str, Tuple[int, int]] = {}
self._flush_every_n = int(flush_every_n)
self._flush_every_secs = float(flush_every_secs)
self._last_flush = time.time()
# start logging the training
def start(self):
if self._started:
return
parent = os.path.dirname(os.path.abspath(self.log_file))
if parent and not os.path.exists(parent):
os.makedirs(parent, exist_ok=True)
self._con = sqlite3.connect(self.log_file, timeout=30.0, isolation_level=None)
self._con.execute("PRAGMA journal_mode=WAL;")
self._con.execute("PRAGMA synchronous=NORMAL;")
self._con.execute("PRAGMA temp_store=MEMORY;")
self._con.execute("PRAGMA foreign_keys=ON;")
self._con.execute("PRAGMA busy_timeout=30000;")
self._init_schema(self._con)
self._started = True
self._last_flush = time.time()
# collect the log to send
def log(self, log_dict):
# log_dict is like {'learning_rate': learning_rate}
if not isinstance(log_dict, dict):
raise TypeError("log_dict must be a dict")
self._log_to_commit.update(log_dict)
# send the log
def commit(self, step: Optional[int] = None):
if not self._started:
self.start()
if not self._log_to_commit:
return
if step is None:
step = self._step_counter
self._step_counter += 1
else:
step = int(step)
if step >= self._step_counter:
self._step_counter = step + 1
wall_time = time.time()
# buffer step row (upsert later)
self._pending_steps.append((step, wall_time))
# buffer metrics rows + key min/max updates
for k, v in self._log_to_commit.items():
k = k if isinstance(k, str) else str(k)
vr, vt = self._coerce_value(v)
self._pending_metrics.append((step, k, vr, vt))
if k in self._pending_key_minmax:
lo, hi = self._pending_key_minmax[k]
if step < lo:
lo = step
if step > hi:
hi = step
self._pending_key_minmax[k] = (lo, hi)
else:
self._pending_key_minmax[k] = (step, step)
self._log_to_commit = {}
# flush conditions
now = time.time()
if (
len(self._pending_metrics) >= self._flush_every_n
or (now - self._last_flush) >= self._flush_every_secs
):
self._flush()
# log image
def log_image(self, *args, **kwargs):
# this doesn't log images for now
pass
# finish logging
def finish(self):
if not self._started:
return
self._flush()
assert self._con is not None
self._con.close()
self._con = None
self._started = False
# -------------------------
# internal
# -------------------------
def _init_schema(self, con: sqlite3.Connection) -> None:
con.execute("BEGIN;")
con.execute("""
CREATE TABLE IF NOT EXISTS steps (
step INTEGER PRIMARY KEY,
wall_time REAL NOT NULL
);
""")
con.execute("""
CREATE TABLE IF NOT EXISTS metric_keys (
key TEXT PRIMARY KEY,
first_seen_step INTEGER,
last_seen_step INTEGER
);
""")
con.execute("""
CREATE TABLE IF NOT EXISTS metrics (
step INTEGER NOT NULL,
key TEXT NOT NULL,
value_real REAL,
value_text TEXT,
PRIMARY KEY (step, key),
FOREIGN KEY (step) REFERENCES steps(step) ON DELETE CASCADE
);
""")
con.execute(
"CREATE INDEX IF NOT EXISTS idx_metrics_key_step ON metrics (key, step);"
)
con.execute("COMMIT;")
def _coerce_value(self, v: Any) -> Tuple[Optional[float], Optional[str]]:
if v is None:
return None, None
if isinstance(v, bool):
return float(int(v)), None
if isinstance(v, (int, float)):
return float(v), None
try:
return float(v), None # type: ignore[arg-type]
except Exception:
return None, str(v)
def _flush(self) -> None:
if not self._pending_steps and not self._pending_metrics:
return
assert self._con is not None
con = self._con
con.execute("BEGIN;")
# steps upsert
if self._pending_steps:
con.executemany(
"INSERT INTO steps(step, wall_time) VALUES(?, ?) "
"ON CONFLICT(step) DO UPDATE SET wall_time=excluded.wall_time;",
self._pending_steps,
)
# keys table upsert (maintains list of keys + seen range)
if self._pending_key_minmax:
con.executemany(
"INSERT INTO metric_keys(key, first_seen_step, last_seen_step) VALUES(?, ?, ?) "
"ON CONFLICT(key) DO UPDATE SET "
"first_seen_step=MIN(metric_keys.first_seen_step, excluded.first_seen_step), "
"last_seen_step=MAX(metric_keys.last_seen_step, excluded.last_seen_step);",
[(k, lo, hi) for k, (lo, hi) in self._pending_key_minmax.items()],
)
# metrics upsert
if self._pending_metrics:
con.executemany(
"INSERT INTO metrics(step, key, value_real, value_text) VALUES(?, ?, ?, ?) "
"ON CONFLICT(step, key) DO UPDATE SET "
"value_real=excluded.value_real, value_text=excluded.value_text;",
self._pending_metrics,
)
con.execute("COMMIT;")
self._pending_steps.clear()
self._pending_metrics.clear()
self._pending_key_minmax.clear()
self._last_flush = time.time()
# create logger based on the logging config
-def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
+def create_logger(
+    logging_config: LoggingConfig,
+    all_config: OrderedDict,
+    save_root: Optional[str] = None,
+):
if logging_config.use_wandb:
project_name = logging_config.project_name
run_name = logging_config.run_name
return WandbLogger(project=project_name, run_name=run_name, config=all_config)
elif logging_config.use_ui_logger:
if save_root is None:
raise ValueError("save_root must be provided when using UILogger")
log_file = os.path.join(save_root, "loss_log.db")
return UILogger(log_file=log_file)
else:
return EmptyLogger()
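For context, here is a minimal sketch of how the new UILogger is driven during a run, assuming the class as defined above (the db path is illustrative). Each commit buffers one row per metric, and a flush to SQLite is forced once 256 metrics are buffered or 0.25 s has passed:

# minimal driving loop for UILogger (defined above); path is illustrative
logger = UILogger(log_file='output/my_run/loss_log.db')
logger.start()  # opens the db, enables WAL, creates the schema

for step in range(100):
    logger.log({'loss': 1.0 / (step + 1), 'learning_rate': 1e-4})
    logger.commit(step=step)  # buffers rows; flushed every 256 metrics or 0.25 s

logger.finish()  # final flush and close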

ui/package-lock.json (generated): diff suppressed because it is too large (884 lines changed)

View File

@@ -19,7 +19,7 @@
"axios": "^1.7.9",
"classnames": "^2.5.1",
"lucide-react": "^0.475.0",
"next": "15.1.11",
"next": "^15.5.9",
"node-cache": "^5.1.2",
"prisma": "^6.3.1",
"react": "^19.0.0",
@@ -28,6 +28,7 @@
"react-global-hooks": "^1.3.5",
"react-icons": "^5.5.0",
"react-select": "^5.10.1",
"recharts": "^3.6.0",
"sqlite3": "^5.1.7",
"systeminformation": "^5.27.11",
"uuid": "^11.1.0",

View File

@@ -0,0 +1,98 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';
import sqlite3 from 'sqlite3';
export const runtime = 'nodejs';
const prisma = new PrismaClient();
function openDb(filename: string) {
const db = new sqlite3.Database(filename);
db.configure('busyTimeout', 30_000);
return db;
}
function all<T = any>(db: sqlite3.Database, sql: string, params: any[] = []) {
return new Promise<T[]>((resolve, reject) => {
db.all(sql, params, (err, rows) => {
if (err) reject(err);
else resolve(rows as T[]);
});
});
}
function closeDb(db: sqlite3.Database) {
return new Promise<void>((resolve, reject) => {
db.close((err) => (err ? reject(err) : resolve()));
});
}
export async function GET(request: NextRequest, { params }: { params: Promise<{ jobID: string }> }) {
  // route params are async in Next.js 15 and must be awaited
  const { jobID } = await params;
const job = await prisma.job.findUnique({ where: { id: jobID } });
if (!job) return NextResponse.json({ error: 'Job not found' }, { status: 404 });
const trainingFolder = await getTrainingFolder();
const jobFolder = path.join(trainingFolder, job.name);
const logPath = path.join(jobFolder, 'loss_log.db');
if (!fs.existsSync(logPath)) {
return NextResponse.json({ keys: [], key: 'loss', points: [] });
}
const url = new URL(request.url);
const key = url.searchParams.get('key') ?? 'loss';
const limit = Math.min(Number(url.searchParams.get('limit') ?? 2000), 20000);
const sinceStepParam = url.searchParams.get('since_step');
const sinceStep = sinceStepParam != null ? Number(sinceStepParam) : null;
const stride = Math.max(1, Number(url.searchParams.get('stride') ?? 1));
const db = openDb(logPath);
try {
const keysRows = await all<{ key: string }>(db, `SELECT key FROM metric_keys ORDER BY key ASC`);
const keys = keysRows.map((r) => r.key);
const points = await all<{
step: number;
wall_time: number;
value: number | null;
value_text: string | null;
}>(
db,
`
SELECT
m.step AS step,
s.wall_time AS wall_time,
m.value_real AS value,
m.value_text AS value_text
FROM metrics m
JOIN steps s ON s.step = m.step
WHERE m.key = ?
AND (? IS NULL OR m.step > ?)
AND (m.step % ?) = 0
ORDER BY m.step ASC
LIMIT ?
`,
[key, sinceStep, sinceStep, stride, limit]
);
return NextResponse.json({
key,
keys,
points: points.map((p) => ({
step: p.step,
wall_time: p.wall_time,
value: p.value ?? (p.value_text ? Number(p.value_text) : null),
})),
});
} finally {
await closeDb(db);
}
}
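The downsampling this route performs can be checked directly against loss_log.db; a minimal sqlite3 sketch of the same query (path and parameter values are illustrative), where since_step binds to NULL on the first fetch so the `(? IS NULL OR m.step > ?)` guard passes everything:

import sqlite3

con = sqlite3.connect('output/my_run/loss_log.db')
key, since_step, stride, limit = 'loss', None, 4, 2000
rows = con.execute(
    """
    SELECT m.step, s.wall_time, m.value_real
    FROM metrics m
    JOIN steps s ON s.step = m.step
    WHERE m.key = ?
      AND (? IS NULL OR m.step > ?)
      AND (m.step % ?) = 0
    ORDER BY m.step ASC
    LIMIT ?
    """,
    (key, since_step, since_step, stride, limit),
).fetchall()
con.close()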

View File

@@ -10,9 +10,10 @@ import JobOverview from '@/components/JobOverview';
import { redirect } from 'next/navigation';
import JobActionBar from '@/components/JobActionBar';
import JobConfigViewer from '@/components/JobConfigViewer';
import JobLossGraph from '@/components/JobLossGraph';
import { Job } from '@prisma/client';
-type PageKey = 'overview' | 'samples' | 'config';
+type PageKey = 'overview' | 'samples' | 'config' | 'loss_log';
interface Page {
name: string;
@@ -36,6 +37,12 @@ const pages: Page[] = [
menuItem: SampleImagesMenu,
mainCss: 'pt-24',
},
{
name: 'Loss Graph',
value: 'loss_log',
component: JobLossGraph,
mainCss: 'pt-24',
},
{
name: 'Config File',
value: 'config',

View File

@@ -92,6 +92,10 @@ export const defaultJobConfig: JobConfig = {
switch_boundary_every: 1,
loss_type: 'mse',
},
logging: {
log_every: 1,
use_ui_logger: true,
},
model: {
name_or_path: 'ostris/Flex.1-alpha',
quantize: true,
@@ -187,5 +191,13 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
false) as boolean;
delete jobConfig.config.process[0].model.auto_memory;
}
if (!('logging' in jobConfig.config.process[0])) {
//@ts-ignore
jobConfig.config.process[0].logging = {
log_every: 1,
use_ui_logger: true,
};
}
return jobConfig;
};
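On the Python side, the `logging` block these defaults write out feeds LoggingConfig (see the config_modules diff earlier in this commit); a minimal sketch of the mapping:

from toolkit.config_modules import LoggingConfig

# the same keys the UI writes into the job config's `logging` block
cfg = LoggingConfig(**{'log_every': 1, 'use_ui_logger': True})
assert cfg.use_ui_logger and not cfg.use_wandb  # selects the UILogger path in create_logger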

View File

@@ -0,0 +1,432 @@
'use client';
import { Job } from '@prisma/client';
import useJobLossLog, { LossPoint } from '@/hooks/useJobLossLog';
import { useMemo, useState, useEffect } from 'react';
import { ResponsiveContainer, LineChart, Line, XAxis, YAxis, Tooltip, CartesianGrid, Legend } from 'recharts';
interface Props {
job: Job;
}
function formatNum(v: number) {
if (!Number.isFinite(v)) return '';
if (Math.abs(v) >= 1000) return v.toFixed(0);
if (Math.abs(v) >= 10) return v.toFixed(3);
if (Math.abs(v) >= 1) return v.toFixed(4);
return v.toPrecision(4);
}
function clamp01(x: number) {
return Math.max(0, Math.min(1, x));
}
// EMA smoothing that works on a per-series list.
// alpha=1 -> no smoothing, alpha closer to 0 -> more smoothing.
function emaSmoothPoints(points: { step: number; value: number }[], alpha: number) {
if (points.length === 0) return [];
const a = clamp01(alpha);
const out: { step: number; value: number }[] = new Array(points.length);
let prev = points[0].value;
out[0] = { step: points[0].step, value: prev };
for (let i = 1; i < points.length; i++) {
const x = points[i].value;
prev = a * x + (1 - a) * prev;
out[i] = { step: points[i].step, value: prev };
}
return out;
}
function hashToIndex(str: string, mod: number) {
let h = 2166136261;
for (let i = 0; i < str.length; i++) {
h ^= str.charCodeAt(i);
h = Math.imul(h, 16777619);
}
return Math.abs(h) % mod;
}
const PALETTE = [
'rgba(96,165,250,1)', // blue-400
'rgba(52,211,153,1)', // emerald-400
'rgba(167,139,250,1)', // purple-400
'rgba(251,191,36,1)', // amber-400
'rgba(244,114,182,1)', // pink-400
'rgba(248,113,113,1)', // red-400
'rgba(34,211,238,1)', // cyan-400
'rgba(129,140,248,1)', // indigo-400
];
function strokeForKey(key: string) {
return PALETTE[hashToIndex(key, PALETTE.length)];
}
export default function JobLossGraph({ job }: Props) {
const { series, lossKeys, status, refreshLoss } = useJobLossLog(job.id, 2000);
// Controls
const [useLogScale, setUseLogScale] = useState(false);
const [showRaw, setShowRaw] = useState(false);
const [showSmoothed, setShowSmoothed] = useState(true);
// 0..100 slider. 0 = no smoothing, 100 = heavy smoothing.
const [smoothing, setSmoothing] = useState(90);
// UI-only downsample for rendering speed
const [plotStride, setPlotStride] = useState(1);
// show only last N points in the chart (0 = all)
const [windowSize, setWindowSize] = useState<number>(4000);
// quick y clipping for readability
const [clipOutliers, setClipOutliers] = useState(false);
// which loss series are enabled (default: all enabled)
const [enabled, setEnabled] = useState<Record<string, boolean>>({});
// keep enabled map in sync with discovered keys (enable new ones automatically)
useEffect(() => {
setEnabled(prev => {
const next = { ...prev };
for (const k of lossKeys) {
if (next[k] === undefined) next[k] = true;
}
// drop removed keys
for (const k of Object.keys(next)) {
if (!lossKeys.includes(k)) delete next[k];
}
return next;
});
}, [lossKeys]);
const activeKeys = useMemo(() => lossKeys.filter(k => enabled[k] !== false), [lossKeys, enabled]);
const perSeries = useMemo(() => {
// Build per-series processed point arrays (raw + smoothed), then merge by step for charting.
const stride = Math.max(1, plotStride | 0);
// smoothing%: 0 => no smoothing (alpha=1.0), 100 => heavy smoothing (alpha=0.02)
const t = clamp01(smoothing / 100);
const alpha = 1.0 - t * 0.98; // 1.0 -> 0.02
const out: Record<string, { raw: { step: number; value: number }[]; smooth: { step: number; value: number }[] }> =
{};
for (const key of activeKeys) {
const pts: LossPoint[] = series[key] ?? [];
let raw = pts
.filter(p => p.value !== null && Number.isFinite(p.value as number))
.map(p => ({ step: p.step, value: p.value as number }))
.filter(p => (useLogScale ? p.value > 0 : true))
.filter((_, idx) => idx % stride === 0);
// windowing (applies after stride)
if (windowSize > 0 && raw.length > windowSize) {
raw = raw.slice(raw.length - windowSize);
}
const smooth = emaSmoothPoints(raw, alpha);
out[key] = { raw, smooth };
}
return out;
}, [series, activeKeys, smoothing, plotStride, windowSize, useLogScale]);
const chartData = useMemo(() => {
// Merge series into one array of objects keyed by step.
// Fields: `${key}__raw` and `${key}__smooth`
const map = new Map<number, any>();
for (const key of activeKeys) {
const s = perSeries[key];
if (!s) continue;
for (const p of s.raw) {
const row = map.get(p.step) ?? { step: p.step };
row[`${key}__raw`] = p.value;
map.set(p.step, row);
}
for (const p of s.smooth) {
const row = map.get(p.step) ?? { step: p.step };
row[`${key}__smooth`] = p.value;
map.set(p.step, row);
}
}
const arr = Array.from(map.values());
arr.sort((a, b) => a.step - b.step);
return arr;
}, [activeKeys, perSeries]);
const hasData = chartData.length > 1;
const yDomain = useMemo((): [number | 'auto', number | 'auto'] => {
if (!clipOutliers || chartData.length < 10) return ['auto', 'auto'];
// Collect visible values (prefer smoothed if shown, else raw)
const vals: number[] = [];
for (const row of chartData) {
for (const key of activeKeys) {
const k = showSmoothed ? `${key}__smooth` : `${key}__raw`;
const v = row[k];
if (typeof v === 'number' && Number.isFinite(v)) vals.push(v);
}
}
if (vals.length < 10) return ['auto', 'auto'];
vals.sort((a, b) => a - b);
const lo = vals[Math.floor(vals.length * 0.02)];
const hi = vals[Math.ceil(vals.length * 0.98) - 1];
if (!Number.isFinite(lo) || !Number.isFinite(hi) || lo === hi) return ['auto', 'auto'];
return [lo, hi];
}, [clipOutliers, chartData, activeKeys, showSmoothed]);
const latestSummary = useMemo(() => {
// Provide a simple “latest” readout for the first active series
const firstKey = activeKeys[0];
if (!firstKey) return null;
const s = perSeries[firstKey];
if (!s) return null;
const lastRaw = s.raw.length ? s.raw[s.raw.length - 1] : null;
const lastSmooth = s.smooth.length ? s.smooth[s.smooth.length - 1] : null;
return {
key: firstKey,
step: lastRaw?.step ?? lastSmooth?.step ?? null,
raw: lastRaw?.value ?? null,
smooth: lastSmooth?.value ?? null,
};
}, [activeKeys, perSeries]);
return (
<div className="bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800 flex flex-col">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<div className="flex items-center gap-2">
<div className="h-2 w-2 rounded-full bg-blue-400" />
<h2 className="text-gray-100 text-sm font-medium">Loss graph</h2>
<span className="text-xs text-gray-400">
{status === 'loading' && 'Loading...'}
{status === 'refreshing' && 'Refreshing...'}
{status === 'error' && 'Error'}
{status === 'success' && hasData && `${chartData.length.toLocaleString()} steps`}
{status === 'success' && !hasData && 'No data yet'}
</span>
</div>
<button
type="button"
onClick={refreshLoss}
className="px-3 py-1 rounded-md text-xs bg-gray-700/60 hover:bg-gray-700 text-gray-200 border border-gray-700"
>
Refresh
</button>
</div>
{/* Chart */}
<div className="px-4 pt-4 pb-4">
<div className="bg-gray-950 rounded-lg border border-gray-800 h-96 relative">
{!hasData ? (
<div className="h-full w-full flex items-center justify-center text-sm text-gray-400">
{status === 'error' ? 'Failed to load loss logs.' : 'Waiting for loss points...'}
</div>
) : (
<ResponsiveContainer width="100%" height="100%">
<LineChart data={chartData} margin={{ top: 10, right: 16, bottom: 10, left: 8 }}>
<CartesianGrid strokeDasharray="3 3" stroke="rgba(255,255,255,0.06)" />
<XAxis
dataKey="step"
tick={{ fill: 'rgba(255,255,255,0.55)', fontSize: 12 }}
tickLine={{ stroke: 'rgba(255,255,255,0.15)' }}
axisLine={{ stroke: 'rgba(255,255,255,0.15)' }}
minTickGap={40}
/>
<YAxis
scale={useLogScale ? 'log' : 'linear'}
tick={{ fill: 'rgba(255,255,255,0.55)', fontSize: 12 }}
tickLine={{ stroke: 'rgba(255,255,255,0.15)' }}
axisLine={{ stroke: 'rgba(255,255,255,0.15)' }}
width={72}
tickFormatter={formatNum}
domain={yDomain}
allowDataOverflow={clipOutliers}
/>
<Tooltip
cursor={{ stroke: 'rgba(59,130,246,0.25)', strokeWidth: 1 }}
contentStyle={{
background: 'rgba(17,24,39,0.96)',
border: '1px solid rgba(31,41,55,1)',
borderRadius: 10,
color: 'rgba(255,255,255,0.9)',
fontSize: 12,
}}
labelStyle={{ color: 'rgba(255,255,255,0.75)' }}
labelFormatter={(label: any) => `step ${label}`}
formatter={(value: any, name: any) => [formatNum(Number(value)), name]}
/>
<Legend
wrapperStyle={{
paddingTop: 8,
color: 'rgba(255,255,255,0.7)',
fontSize: 12,
}}
/>
{activeKeys.map(k => {
const color = strokeForKey(k);
return (
<g key={k}>
{showRaw && (
<Line
type="monotone"
dataKey={`${k}__raw`}
name={`${k} (raw)`}
stroke={color.replace('1)', '0.40)')}
strokeWidth={1.25}
dot={false}
isAnimationActive={false}
/>
)}
{showSmoothed && (
<Line
type="monotone"
dataKey={`${k}__smooth`}
name={`${k}`}
stroke={color}
strokeWidth={2}
dot={false}
isAnimationActive={false}
/>
)}
</g>
);
})}
</LineChart>
</ResponsiveContainer>
)}
</div>
</div>
{/* Controls */}
<div className="px-4 pb-2">
<div className="grid grid-cols-1 md:grid-cols-2 gap-3">
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
<label className="block text-xs text-gray-400 mb-2">Display</label>
<div className="flex flex-wrap gap-2">
<ToggleButton checked={showSmoothed} onClick={() => setShowSmoothed(v => !v)} label="Smoothed" />
<ToggleButton checked={showRaw} onClick={() => setShowRaw(v => !v)} label="Raw" />
<ToggleButton checked={useLogScale} onClick={() => setUseLogScale(v => !v)} label="Log Y" />
<ToggleButton checked={clipOutliers} onClick={() => setClipOutliers(v => !v)} label="Clip outliers" />
</div>
</div>
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
<label className="block text-xs text-gray-400 mb-2">Series</label>
{lossKeys.length === 0 ? (
<div className="text-sm text-gray-400">No loss keys found yet.</div>
) : (
<div className="flex flex-wrap gap-2">
{lossKeys.map(k => (
<button
key={k}
type="button"
onClick={() => setEnabled(prev => ({ ...prev, [k]: !(prev[k] ?? true) }))}
className={[
'px-3 py-1 rounded-md text-xs border transition-colors',
enabled[k] === false
? 'bg-gray-900 text-gray-400 border-gray-800 hover:bg-gray-800/60'
: 'bg-gray-900 text-gray-200 border-gray-800 hover:bg-gray-800/60',
].join(' ')}
aria-pressed={enabled[k] !== false}
title={k}
>
<span className="inline-block h-2 w-2 rounded-full mr-2" style={{ background: strokeForKey(k) }} />
{k}
</button>
))}
</div>
)}
</div>
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
<div className="flex items-center justify-between mb-1">
<label className="block text-xs text-gray-400">Smoothing</label>
<span className="text-xs text-gray-300">{smoothing}%</span>
</div>
<input
type="range"
min={0}
max={100}
value={smoothing}
onChange={e => setSmoothing(Number(e.target.value))}
className="w-full accent-blue-500"
disabled={!showSmoothed}
/>
</div>
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3">
<div className="flex items-center justify-between mb-1">
<label className="block text-xs text-gray-400">Plot stride</label>
<span className="text-xs text-gray-300">every {plotStride} pt</span>
</div>
<input
type="range"
min={1}
max={20}
value={plotStride}
onChange={e => setPlotStride(Number(e.target.value))}
className="w-full accent-blue-500"
/>
<div className="mt-2 text-[11px] text-gray-500">UI downsample for huge runs.</div>
</div>
<div className="bg-gray-950 border border-gray-800 rounded-lg p-3 md:col-span-2">
<div className="flex items-center justify-between mb-1">
<label className="block text-xs text-gray-400">Window (last N points)</label>
<span className="text-xs text-gray-300">{windowSize === 0 ? 'all' : windowSize.toLocaleString()}</span>
</div>
<input
type="range"
min={0}
max={20000}
step={250}
value={windowSize}
onChange={e => setWindowSize(Number(e.target.value))}
className="w-full accent-blue-500"
/>
<div className="mt-2 text-[11px] text-gray-500">
Set to 0 to show all (not recommended for very long runs).
</div>
</div>
</div>
</div>
</div>
);
}
function ToggleButton({ checked, onClick, label }: { checked: boolean; onClick: () => void; label: string }) {
return (
<button
type="button"
onClick={onClick}
className={[
'px-3 py-1 rounded-md text-xs border transition-colors',
checked
? 'bg-blue-500/10 text-blue-300 border-blue-500/30 hover:bg-blue-500/15'
: 'bg-gray-900 text-gray-300 border-gray-800 hover:bg-gray-800/60',
].join(' ')}
aria-pressed={checked}
>
{label}
</button>
);
}
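The smoothing slider maps 0–100 onto an EMA coefficient via alpha = 1 − t·0.98, so 0% gives alpha 1.0 (no smoothing) and 100% gives alpha 0.02 (heavy smoothing). A minimal Python sketch of the same recurrence used by emaSmoothPoints above:

def ema_smooth(values, smoothing_pct):
    # slider 0..100 -> alpha in [1.0, 0.02]; alpha=1 leaves values unchanged
    t = min(max(smoothing_pct / 100.0, 0.0), 1.0)
    alpha = 1.0 - t * 0.98
    out, prev = [], None
    for v in values:
        prev = v if prev is None else alpha * v + (1 - alpha) * prev
        out.append(prev)
    return out

print(ema_smooth([1.0, 0.5, 0.25], 90))  # ~[1.0, 0.941, 0.859] at alpha=0.118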

View File

@@ -0,0 +1,145 @@
'use client';
import { useEffect, useState, useRef, useCallback, useMemo } from 'react';
import { apiClient } from '@/utils/api';
export interface LossPoint {
step: number;
wall_time?: number;
value: number | null;
}
type SeriesMap = Record<string, LossPoint[]>;
function isLossKey(key: string) {
// treat anything containing "loss" as a loss-series
// (covers loss, train_loss, val_loss, loss/xyz, etc.)
return /loss/i.test(key);
}
export default function useJobLossLog(jobID: string, reloadInterval: null | number = null) {
const [series, setSeries] = useState<SeriesMap>({});
const [keys, setKeys] = useState<string[]>([]);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle');
const didInitialLoadRef = useRef(false);
const inFlightRef = useRef(false);
// track last step per key so polling is incremental per series
const lastStepByKeyRef = useRef<Record<string, number | null>>({});
const lossKeys = useMemo(() => {
const base = (keys ?? []).filter(isLossKey);
// if keys table is empty early on, fall back to just "loss"
if (base.length === 0) return ['loss'];
return base.sort();
}, [keys]);
const refreshLoss = useCallback(async () => {
if (!jobID) return;
if (inFlightRef.current) return;
inFlightRef.current = true;
const loadStatus: 'loading' | 'refreshing' = didInitialLoadRef.current ? 'refreshing' : 'loading';
setStatus(loadStatus);
try {
// Step 1: get the key list (the endpoint returns `keys` on every call).
// Keep it cheap: limit=1.
const first = await apiClient
.get(`/api/jobs/${jobID}/loss`, { params: { key: 'loss', limit: 1 } })
.then(res => res.data as { keys?: string[] });
const newKeys = first.keys ?? [];
setKeys(newKeys);
const wantedLossKeys = (newKeys.filter(isLossKey).length ? newKeys.filter(isLossKey) : ['loss']).sort();
// Step 2: fetch each loss key incrementally (since_step per key if polling)
const requests = wantedLossKeys.map(k => {
const params: Record<string, any> = { key: k };
if (reloadInterval && lastStepByKeyRef.current[k] != null) {
params.since_step = lastStepByKeyRef.current[k];
}
// keep default limit from server (or set explicitly if you want)
// params.limit = 2000;
return apiClient
.get(`/api/jobs/${jobID}/loss`, { params })
.then(res => res.data as { key: string; points?: LossPoint[] });
});
const results = await Promise.all(requests);
setSeries(prev => {
const next: SeriesMap = { ...prev };
for (const r of results) {
const k = r.key;
const newPoints = (r.points ?? []).filter(p => p.value !== null);
if (!didInitialLoadRef.current) {
// initial: replace
next[k] = newPoints;
} else if (newPoints.length) {
const existing = next[k] ?? [];
const prevLast = existing.length ? existing[existing.length - 1].step : null;
const filtered = prevLast == null ? newPoints : newPoints.filter(p => p.step > prevLast);
next[k] = filtered.length ? [...existing, ...filtered] : existing;
} else {
// no new points: keep existing
next[k] = next[k] ?? [];
}
// update last step per key
const finalArr = next[k] ?? [];
lastStepByKeyRef.current[k] = finalArr.length
? finalArr[finalArr.length - 1].step
: (lastStepByKeyRef.current[k] ?? null);
}
// remove stale loss keys that no longer exist (rare, but keeps UI clean)
for (const existingKey of Object.keys(next)) {
if (isLossKey(existingKey) && !wantedLossKeys.includes(existingKey)) {
delete next[existingKey];
delete lastStepByKeyRef.current[existingKey];
}
}
return next;
});
setStatus('success');
didInitialLoadRef.current = true;
} catch (err) {
console.error('Error fetching loss logs:', err);
setStatus('error');
} finally {
inFlightRef.current = false;
}
}, [jobID, reloadInterval]);
useEffect(() => {
// reset when job changes
didInitialLoadRef.current = false;
lastStepByKeyRef.current = {};
setSeries({});
setKeys([]);
setStatus('idle');
refreshLoss();
if (reloadInterval) {
const interval = setInterval(() => {
refreshLoss();
}, reloadInterval);
return () => clearInterval(interval);
}
}, [jobID, reloadInterval, refreshLoss]);
return { series, keys, lossKeys, status, refreshLoss, setSeries };
}
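Stripped of the HTTP layer, the hook's incremental polling reduces to: remember the last step held per key, request only newer points via since_step, and append anything past that watermark. A minimal sketch of that merge logic (names are illustrative):

last_step = {}   # key -> last step already held client-side
series = {}      # key -> list of (step, value)

def merge(key, new_points):
    # append only points newer than what we already have (mirrors the hook)
    prev_last = last_step.get(key)
    fresh = [(s, v) for s, v in new_points if prev_last is None or s > prev_last]
    series.setdefault(key, []).extend(fresh)
    if series[key]:
        last_step[key] = series[key][-1][0]

merge('loss', [(0, 1.0), (1, 0.8)])
merge('loss', [(1, 0.8), (2, 0.6)])  # overlapping fetch; only step 2 is appended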

View File

@@ -197,6 +197,11 @@ export interface SampleConfig {
fps: number;
}
export interface LoggingConfig {
log_every: number;
use_ui_logger: boolean;
}
export interface SliderConfig {
guidance_strength?: number;
anchor_strength?: number;
@@ -218,6 +223,7 @@ export interface ProcessConfig {
save: SaveConfig;
datasets: DatasetConfig[];
train: TrainConfig;
logging: LoggingConfig;
model: ModelConfig;
sample: SampleConfig;
}

View File

@@ -1 +1 @@
VERSION = "0.7.9"
VERSION = "0.7.10"