Merge pull request #256 from ostris/ui

Added AI-Toolkit UI
This commit is contained in:
Jaret Burkett
2025-02-23 16:10:49 -07:00
committed by GitHub
98 changed files with 10345 additions and 215 deletions

4
.gitignore vendored
View File

@@ -161,6 +161,7 @@ cython_debug/
/env.sh
/models
/datasets
/custom/*
!/custom/.gitkeep
/.tmp
@@ -177,4 +178,5 @@ cython_debug/
/wandb
.vscode/settings.json
.DS_Store
._.DS_Store
._.DS_Store
aitk_db.db

223
README.md
View File

@@ -7,7 +7,7 @@
</a>
I am transitioning to working on my open source AI projects full time. If you find my work useful, please consider supporting me on [Patreon](https://www.patreon.com/ostris). I will be able to work on more projects and provide better support with your help.
I work on open source full time, which means I 100% rely on donations to make a living. If you find this project helpful, or use it in for commercial purposes, please consider donating to support my work on [Patreon](https://www.patreon.com/ostris) or [Github Sponsors](https://github.com/sponsors/ostris).
## Installation
@@ -18,7 +18,6 @@ Requirements:
- git
Linux:
```bash
git clone https://github.com/ostris/ai-toolkit.git
@@ -43,6 +42,43 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
```
# AI Toolkit UI
<img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">
The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It is still in early beta and will likely have bugs and frequent breaking changes. It is currently only tested on linux for now.
WARNING: The UI is not secure and should not be exposed to the internet. It is only meant to be run locally or on a server that does not have ports exposed. Adding additional security is on the roadmap.
## Installing the UI
Requirements:
- Node.js > 18
You will need to do this with every update as well.
```bash
cd ui
npm install
npm run build
npm run update_db
```
## Running the UI
Make sure you built it as shown above. The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs.
```bash
cd ui
npm run start
```
You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.
## FLUX.1 Training
### Tutorial
@@ -275,186 +311,3 @@ You can also exclude layers by their names by using `ignore_if_contains` network
`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
if will be ignored.
---
## EVERYTHING BELOW THIS LINE IS OUTDATED
It may still work like that, but I have not tested it in a while.
---
### Batch Image Generation
A image generator that can take frompts from a config file or form a txt file and generate them to a
folder. I mainly needed this for an SDXL test I am doing but added some polish to it so it can be used
for generat batch image generation.
It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`.
Mere info is in the comments in the example
---
### LoRA (lierla), LoCON (LyCORIS) extractor
It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adding some QOL features
and LoRA (lierla) support. It can do multiple types of extractions in one run.
It all runs off a config file, which you can find an example of in `config/examples/extract.example.yml`.
Just copy that file, into the `config` folder, and rename it to `whatever_you_want.yml`.
Then you can edit the file to your liking. and call it like so:
```bash
python3 run.py config/whatever_you_want.yml
```
You can also put a full path to a config file, if you want to keep it somewhere else.
```bash
python3 run.py "/home/user/whatever_you_want.yml"
```
More notes on how it works are available in the example config file itself. LoRA and LoCON both support
extractions of 'fixed', 'threshold', 'ratio', 'quantile'. I'll update what these do and mean later.
Most people used fixed, which is traditional fixed dimension extraction.
`process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LyCON, etc.
---
### LoRA Rescale
Change `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` or whatever you want with the same effect.
A tool for rescaling a LoRA's weights. Should would with LoCON as well, but I have not tested it.
It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yml`.
Just copy that file, into the `config` folder, and rename it to `whatever_you_want.yml`.
Then you can edit the file to your liking. and call it like so:
```bash
python3 run.py config/whatever_you_want.yml
```
You can also put a full path to a config file, if you want to keep it somewhere else.
```bash
python3 run.py "/home/user/whatever_you_want.yml"
```
More notes on how it works are available in the example config file itself. This is useful when making
all LoRAs, as the ideal weight is rarely 1.0, but now you can fix that. For sliders, they can have weird scales form -2 to 2
or even -15 to 15. This will allow you to dile it in so they all have your desired scale
---
### LoRA Slider Trainer
<a target="_blank" href="https://colab.research.google.com/github/ostris/ai-toolkit/blob/main/notebooks/SliderTraining.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
This is how I train most of the recent sliders I have on Civitai, you can check them out in my [Civitai profile](https://civitai.com/user/Ostris/models).
It is based off the work by [p1atdev/LECO](https://github.com/p1atdev/LECO) and [rohitgandikota/erasing](https://github.com/rohitgandikota/erasing)
But has been heavily modified to create sliders rather than erasing concepts. I have a lot more plans on this, but it is
very functional as is. It is also very easy to use. Just copy the example config file in `config/examples/train_slider.example.yml`
to the `config` folder and rename it to `whatever_you_want.yml`. Then you can edit the file to your liking. and call it like so:
```bash
python3 run.py config/whatever_you_want.yml
```
There is a lot more information in that example file. You can even run the example as is without any modifications to see
how it works. It will create a slider that turns all animals into dogs(neg) or cats(pos). Just run it like so:
```bash
python3 run.py config/examples/train_slider.example.yml
```
And you will be able to see how it works without configuring anything. No datasets are required for this method.
I will post an better tutorial soon.
---
## Extensions!!
You can now make and share custom extensions. That run within this framework and have all the inbuilt tools
available to them. I will probably use this as the primary development method going
forward so I dont keep adding and adding more and more features to this base repo. I will likely migrate a lot
of the existing functionality as well to make everything modular. There is an example extension in the `extensions`
folder that shows how to make a model merger extension. All of the code is heavily documented which is hopefully
enough to get you started. To make an extension, just copy that example and replace all the things you need to.
### Model Merger - Example Extension
It is located in the `extensions` folder. It is a fully finctional model merger that can merge as many models together
as you want. It is a good example of how to make an extension, but is also a pretty useful feature as well since most
mergers can only do one model at a time and this one will take as many as you want to feed it. There is an
example config file in there, just copy that to your `config` folder and rename it to `whatever_you_want.yml`.
and use it like any other config file.
## WIP Tools
### VAE (Variational Auto Encoder) Trainer
This works, but is not ready for others to use and therefore does not have an example config.
I am still working on it. I will update this when it is ready.
I am adding a lot of features for criteria that I have used in my image enlargement work. A Critic (discriminator),
content loss, style loss, and a few more. If you don't know, the VAE
for stable diffusion (yes even the MSE one, and SDXL), are horrible at smaller faces and it holds SD back. I will fix this.
I'll post more about this later with better examples later, but here is a quick test of a run through with various VAEs.
Just went in and out. It is much worse on smaller faces than shown here.
<img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">
---
## TODO
- [X] Add proper regs on sliders
- [X] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)
---
## Change Log
#### 2023-08-05
- Huge memory rework and slider rework. Slider training is better thant ever with no more
ram spikes. I also made it so all 4 parts of the slider algorythm run in one batch so they share gradient
accumulation. This makes it much faster and more stable.
- Updated the example config to be something more practical and more updated to current methods. It is now
a detail slide and shows how to train one without a subject. 512x512 slider training for 1.5 should work on
6GB gpu now. Will test soon to verify.
#### 2021-10-20
- Windows support bug fixes
- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
check out the example in the `extensions` folder. Read more about that above.
- Model Merging, provided via the example extension.
#### 2023-08-03
Another big refactor to make SD more modular.
Made batch image generation script
#### 2023-08-01
Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
at the moment, so hopefully there are not breaking changes.
Unfortunately, I am too lazy to write a proper changelog with all the changes.
I added SDXL training to sliders... but.. it does not work properly.
The slider training relies on a model's ability to understand that an unconditional (negative prompt)
means you do not want that concept in the output. SDXL does not understand this for whatever reason,
which makes separating out
concepts within the model hard. I am sure the community will find a way to fix this
over time, but for now, it is not
going to work properly. And if any of you are thinking "Could we maybe fix it by adding 1 or 2 more text
encoders to the model as well as a few more entirely separate diffusion networks?" No. God no. It just needs a little
training without every experimental new paper added to it. The KISS principal.
#### 2023-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights

View File

@@ -0,0 +1,227 @@
from collections import OrderedDict
import os
import sqlite3
import asyncio
import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional
AITK_Status = Literal["running", "stopped", "error", "completed"]
class UITrainer(SDTrainer):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
print(f"Using SQLite database at {self.sqlite_db_path}")
self.job_id = os.environ.get("AITK_JOB_ID", None)
if self.job_id is None:
raise Exception("AITK_JOB_ID not set")
self.is_stopping = False
# Create a thread pool for database operations
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Track all async tasks
self._async_tasks = []
# Initialize the status
self._run_async_operation(self._update_status("running", "Starting"))
def _run_async_operation(self, coro):
"""Helper method to run an async coroutine and track the task."""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
# No event loop exists, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a task and track it
if loop.is_running():
task = asyncio.run_coroutine_threadsafe(coro, loop)
self._async_tasks.append(asyncio.wrap_future(task))
else:
task = loop.create_task(coro)
self._async_tasks.append(task)
loop.run_until_complete(task)
async def _execute_db_operation(self, operation_func):
"""Execute a database operation in a separate thread to avoid blocking."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.thread_pool, operation_func)
def _db_connect(self):
"""Create a new connection for each operation to avoid locking."""
conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
conn.isolation_level = None # Enable autocommit mode
return conn
def should_stop(self):
def _check_stop():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT stop FROM Job WHERE id = ?", (self.job_id,))
stop = cursor.fetchone()
return False if stop is None else stop[0] == 1
return _check_stop()
def maybe_stop(self):
if self.should_stop():
self._run_async_operation(
self._update_status("stopped", "Job stopped"))
self.is_stopping = True
raise Exception("Job stopped")
async def _update_key(self, key, value):
if not self.accelerator.is_main_process:
return
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE")
try:
# Convert the value to string if it's not already
if isinstance(value, str):
value_to_insert = value
else:
value_to_insert = str(value)
# Use parameterized query for both the column name and value
update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
cursor.execute(
update_query, (value_to_insert, self.job_id))
finally:
cursor.execute("COMMIT")
await self._execute_db_operation(_do_update)
def update_step(self):
"""Non-blocking update of the step count."""
if self.accelerator.is_main_process:
self._run_async_operation(self._update_key("step", self.step_num))
def update_db_key(self, key, value):
"""Non-blocking update a key in the database."""
if self.accelerator.is_main_process:
self._run_async_operation(self._update_key(key, value))
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
if not self.accelerator.is_main_process:
return
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE")
try:
if info is not None:
cursor.execute(
"UPDATE Job SET status = ?, info = ? WHERE id = ?",
(status, info, self.job_id)
)
else:
cursor.execute(
"UPDATE Job SET status = ? WHERE id = ?",
(status, self.job_id)
)
finally:
cursor.execute("COMMIT")
await self._execute_db_operation(_do_update)
def update_status(self, status: AITK_Status, info: Optional[str] = None):
"""Non-blocking update of status."""
if self.accelerator.is_main_process:
self._run_async_operation(self._update_status(status, info))
async def wait_for_all_async(self):
"""Wait for all tracked async operations to complete."""
if not self._async_tasks:
return
try:
await asyncio.gather(*self._async_tasks)
finally:
# Clear the task list after completion
self._async_tasks.clear()
def on_error(self, e: Exception):
super(UITrainer, self).on_error(e)
if self.accelerator.is_main_process and not self.is_stopping:
self.update_status("error", str(e))
self.update_db_key("step", self.last_save_step)
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def handle_timing_print_hook(self, timing_dict):
if "train_loop" not in timing_dict:
print("train_loop not found in timing_dict", timing_dict)
return
seconds_per_iter = timing_dict["train_loop"]
# determine iter/sec or sec/iter
if seconds_per_iter < 1:
iters_per_sec = 1 / seconds_per_iter
self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
else:
self.update_db_key(
"speed_string", f"{seconds_per_iter:.2f} sec/iter")
def done_hook(self):
super(UITrainer, self).done_hook()
self.update_status("completed", "Training completed")
# Wait for all async operations to finish before shutting down
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def end_step_hook(self):
super(UITrainer, self).end_step_hook()
self.update_step()
self.maybe_stop()
def hook_before_model_load(self):
super().hook_before_model_load()
self.maybe_stop()
self.update_status("running", "Loading model")
def before_dataset_load(self):
super().before_dataset_load()
self.maybe_stop()
self.update_status("running", "Loading dataset")
def hook_before_train_loop(self):
super().hook_before_train_loop()
self.maybe_stop()
self.update_step()
self.update_status("running", "Training")
self.timer.add_after_print_hook(self.handle_timing_print_hook)
def status_update_hook_func(self, string):
self.update_status("running", string)
def hook_after_sd_init_before_load(self):
super().hook_after_sd_init_before_load()
self.maybe_stop()
self.sd.add_status_update_hook(self.status_update_hook_func)
def sample_step_hook(self, img_num, total_imgs):
super().sample_step_hook(img_num, total_imgs)
self.maybe_stop()
self.update_status(
"running", f"Generating images - {img_num + 1}/{total_imgs}")
def sample(self, step=None, is_first=False):
self.maybe_stop()
total_imgs = len(self.sample_config.prompts)
self.update_status("running", f"Generating images - 0/{total_imgs}")
super().sample(step, is_first)
self.maybe_stop()
self.update_status("running", "Training")
def save(self, step=None):
self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
self.maybe_stop()
self.update_status("running", "Training")

View File

@@ -18,6 +18,22 @@ class SDTrainerExtension(Extension):
from .SDTrainer import SDTrainer
return SDTrainer
# This is for generic training (LoRA, Dreambooth, FineTuning)
class UITrainerExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "ui_trainer"
# name is the name of the extension for printing
name = "UI Trainer"
# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .UITrainer import UITrainer
return UITrainer
# for backwards compatability
class TextualInversionTrainer(SDTrainerExtension):
@@ -26,5 +42,5 @@ class TextualInversionTrainer(SDTrainerExtension):
AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
SDTrainerExtension, TextualInversionTrainer
SDTrainerExtension, TextualInversionTrainer, UITrainerExtension
]

View File

@@ -24,6 +24,9 @@ class BaseProcess(object):
self.performance_log_every = self.get_conf('performance_log_every', 0)
print(json.dumps(self.config, indent=4))
def on_error(self, e: Exception):
pass
def get_conf(self, key, default=None, required=False, as_type=None):
# split key by '.' and recursively get the value

View File

@@ -92,6 +92,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.step_num = 0
self.start_step = 0
self.epoch_num = 0
self.last_save_step = 0
# start at 1 so we can do a sample at the start
self.grad_accumulation_step = 1
# if true, then we do not do an optimizer step. We are accumulating gradients
@@ -439,6 +440,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
def post_save_hook(self, save_path):
# override in subclass
pass
def done_hook(self):
pass
def end_step_hook(self):
pass
def save(self, step=None):
if not self.accelerator.is_main_process:
@@ -453,6 +460,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
step_num = ''
if step is not None:
self.last_save_step = step
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
@@ -648,6 +656,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.logger.start()
self.prepare_accelerator()
def sample_step_hook(self, img_num, total_imgs):
pass
def prepare_accelerator(self):
# set some config
@@ -722,6 +732,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_train_loop(self, batch):
# return loss
return 0.0
def hook_after_sd_init_before_load(self):
pass
def get_latest_save_path(self, name=None, post=''):
if name == None:
@@ -1417,8 +1430,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
self.hook_after_sd_init_before_load()
# run base sd process run
self.sd.load_model()
self.sd.add_after_sample_image_hook(self.sample_step_hook)
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1812,6 +1829,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd)
flush()
self.last_save_step = self.step_num
### HOOK ###
self.hook_before_train_loop()
@@ -2091,6 +2109,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# update various steps
self.step_num = step + 1
self.grad_accumulation_step += 1
self.end_step_hook()
###################################################################
@@ -2110,13 +2129,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.logger.finish()
self.accelerator.end_training()
if self.save_config.push_to_hub:
if("HF_TOKEN" not in os.environ):
interpreter_login(new_session=False, write_permission=True)
self.push_to_hub(
repo_id=self.save_config.hf_repo_id,
private=self.save_config.hf_private
)
if self.accelerator.is_main_process:
# push to hub
if self.save_config.push_to_hub:
if("HF_TOKEN" not in os.environ):
interpreter_login(new_session=False, write_permission=True)
self.push_to_hub(
repo_id=self.save_config.hf_repo_id,
private=self.save_config.hf_private
)
del (
self.sd,
unet,
@@ -2128,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
flush()
self.done_hook()
def push_to_hub(
self,

View File

@@ -32,4 +32,5 @@ sentencepiece
huggingface_hub
peft
gradio
python-slugify
python-slugify
sqlite3

4
run.py
View File

@@ -88,6 +88,10 @@ def main():
except Exception as e:
print_acc(f"Error running job: {e}")
jobs_failed += 1
try:
job.process[0].on_error(e)
except Exception as e2:
print_acc(f"Error running on_error: {e2}")
if not args.recover:
print_end_message(jobs_completed, jobs_failed)
raise e

View File

@@ -379,7 +379,8 @@ class TrainConfig:
self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
if ema_config is not None:
# if it is set explicitly to false, leave it false.
if ema_config is not None and ema_config.get('use_ema', None) is not None:
ema_config['use_ema'] = True
print(f"Using EMA")
else:

View File

@@ -9,6 +9,7 @@ from typing import List, Optional, Dict, Type, Union
import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
from transformers import CLIPTextModel
from toolkit.models.lokr import LokrModule
from .config_modules import NetworkConfig
from .lorm import count_parameters

282
toolkit/models/lokr.py Normal file
View File

@@ -0,0 +1,282 @@
# based heavily on https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from toolkit.network_mixins import ToolkitModuleMixin
from typing import TYPE_CHECKING, Union, List
if TYPE_CHECKING:
from toolkit.lora_special import LoRASpecialNetwork
# 4, build custom backward function
# -
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
'''
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
In LoRA with Kroneckor Product, first value is a value for weight scale.
secon value is a value for weight.
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
examples)
factor
-1 2 4 8 16 ...
127 -> 127, 1 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1 127 -> 127, 1
128 -> 16, 8 128 -> 64, 2 128 -> 32, 4 128 -> 16, 8 128 -> 16, 8
250 -> 125, 2 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2 250 -> 125, 2
360 -> 45, 8 360 -> 180, 2 360 -> 90, 4 360 -> 45, 8 360 -> 45, 8
512 -> 32, 16 512 -> 256, 2 512 -> 128, 4 512 -> 64, 8 512 -> 32, 16
1024 -> 32, 32 1024 -> 512, 2 1024 -> 256, 4 1024 -> 128, 8 1024 -> 64, 16
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
return m, n
if factor == -1:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
new_m = m + 1
while dimension%new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
def make_weight_cp(t, wa, wb):
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
return rebuild2
def make_kron(w1, w2, scale):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
rebuild = torch.kron(w1, w2)
return rebuild*scale
class LokrModule(ToolkitModuleMixin, nn.Module):
"""
modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=0.,
rank_dropout=0.,
module_dropout=0.,
use_cp=False,
decompose_both = False,
network: 'LoRASpecialNetwork' = None,
factor:int=-1, # factorization factor
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
ToolkitModuleMixin.__init__(self, network=network)
torch.nn.Module.__init__(self)
factor = int(factor)
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False
self.use_w1 = False
self.use_w2 = False
self.shape = org_module.weight.shape
if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels
k_size = org_module.kernel_size
out_dim = org_module.out_channels
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
self.cp = use_cp and k_size!=(1, 1)
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim >= max(shape[0][1], shape[1][1])/2:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
elif self.cp:
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
else: # Conv2d not cp
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
self.op = F.conv2d
self.extra_args = {
"stride": org_module.stride,
"padding": org_module.padding,
"dilation": org_module.dilation,
"groups": org_module.groups
}
else: # Linear
in_dim = org_module.in_features
out_dim = org_module.out_features
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
# smaller part. weight scale
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
else:
self.use_w1 = True
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
if lora_dim < max(shape[0][1], shape[1][1])/2:
# bigger part. weight and LoRA. [b, dim] x [dim, d]
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
self.op = F.linear
self.extra_args = {}
self.dropout = dropout
if dropout:
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
if self.use_w2 and self.use_w1:
#use scale = 1
alpha = lora_dim
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
else:
if self.cp:
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
torch.nn.init.constant_(self.lokr_w2_b, 0)
if self.use_w1:
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
else:
torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5))
self.multiplier = multiplier
self.org_module = [org_module]
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
else self.lokr_w2_a@self.lokr_w2_b),
torch.tensor(self.multiplier * self.scale)
)
assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"
# Same as locon.py
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, orig_weight = None):
weight = make_kron(
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
(self.lokr_w2 if self.use_w2
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
else self.lokr_w2_a@self.lokr_w2_b),
torch.tensor(self.scale)
)
if orig_weight is not None:
weight = weight.reshape(orig_weight.shape)
if self.training and self.rank_dropout:
drop = torch.rand(weight.size(0)) < self.rank_dropout
weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
return weight
@torch.no_grad()
def apply_max_norm(self, max_norm, device=None):
orig_norm = self.get_weight().norm()
norm = torch.clamp(orig_norm, max_norm/2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu()/norm.cpu()
scaled = ratio.item() != 1.0
if scaled:
modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
if self.use_w1:
self.lokr_w1 *= ratio**(1/modules)
else:
self.lokr_w1_a *= ratio**(1/modules)
self.lokr_w1_b *= ratio**(1/modules)
if self.use_w2:
self.lokr_w2 *= ratio**(1/modules)
else:
if self.cp:
self.lokr_t2 *= ratio**(1/modules)
self.lokr_w2_a *= ratio**(1/modules)
self.lokr_w2_b *= ratio**(1/modules)
return scaled, orig_norm*ratio
def forward(self, x):
if self.module_dropout and self.training:
if torch.rand(1) < self.module_dropout:
return self.op(
x,
self.org_module[0].weight.data,
None if self.org_module[0].bias is None else self.org_module[0].bias.data
)
weight = (
self.org_module[0].weight.data
+ self.get_weight(self.org_module[0].weight.data) * self.multiplier
)
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
return self.op(
x,
weight.view(self.shape),
bias,
**self.extra_args
)

View File

@@ -202,6 +202,8 @@ class StableDiffusion:
# merge in and preview active with -1 weight
self.invert_assistant_lora = False
self._after_sample_img_hooks = []
self._status_update_hooks = []
def load_model(self):
if self.is_loaded:
@@ -540,10 +542,10 @@ class StableDiffusion:
tokenizer = pipe.tokenizer
elif self.model_config.is_flux:
print_acc("Loading Flux model")
self.print_and_status_update("Loading Flux model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
print_acc("Loading transformer")
self.print_and_status_update("Loading transformer")
subfolder = 'transformer'
transformer_path = model_path
local_files_only = False
@@ -688,7 +690,7 @@ class StableDiffusion:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print_acc("Quantizing transformer")
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
@@ -698,7 +700,7 @@ class StableDiffusion:
flush()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print_acc("Loading vae")
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
@@ -707,7 +709,7 @@ class StableDiffusion:
text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
else:
print_acc("Loading t5")
self.print_and_status_update("Loading T5")
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
torch_dtype=dtype)
@@ -717,19 +719,19 @@ class StableDiffusion:
if self.model_config.quantize_te:
if self.is_flex2:
print_acc("Quantizing LLM")
self.print_and_status_update("Quantizing LLM")
else:
print_acc("Quantizing T5")
self.print_and_status_update("Quantizing T5")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
flush()
print_acc("Loading clip")
self.print_and_status_update("Loading CLIP")
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
print_acc("making pipe")
self.print_and_status_update("Making pipe")
Pipe = FluxPipeline
if self.is_flex2:
Pipe = Flex2Pipeline
@@ -746,7 +748,7 @@ class StableDiffusion:
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
print_acc("preparing")
self.print_and_status_update("Preparing Model")
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
@@ -763,10 +765,10 @@ class StableDiffusion:
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
elif self.model_config.is_lumina2:
print_acc("Loading Lumina2 model")
self.print_and_status_update("Loading Lumina2 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
print_acc("Loading transformer")
self.print_and_status_update("Loading transformer")
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
@@ -802,7 +804,7 @@ class StableDiffusion:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print_acc("Quantizing transformer")
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
freeze(transformer)
transformer.to(self.device_torch)
@@ -812,16 +814,16 @@ class StableDiffusion:
flush()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print_acc("Loading vae")
self.print_and_status_update("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
if self.model_config.te_name_or_path is not None:
print_acc("Loading TE")
self.print_and_status_update("Loading TE")
tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
else:
print_acc("Loading Gemma2")
self.print_and_status_update("Loading Gemma2")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
@@ -829,12 +831,12 @@ class StableDiffusion:
flush()
if self.model_config.quantize_te:
print_acc("Quantizing Gemma2")
self.print_and_status_update("Quantizing Gemma2")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)
flush()
print_acc("making pipe")
self.print_and_status_update("Making pipe")
pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline(
scheduler=scheduler,
text_encoder=None,
@@ -845,7 +847,7 @@ class StableDiffusion:
pipe.text_encoder = text_encoder
pipe.transformer = transformer
print_acc("preparing")
self.print_and_status_update("Preparing Model")
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
@@ -1032,6 +1034,25 @@ class StableDiffusion:
self.refiner_unet = refiner.unet
del refiner
flush()
def _after_sample_image(self, img_num, total_imgs):
# process all hooks
for hook in self._after_sample_img_hooks:
hook(img_num, total_imgs)
def add_after_sample_image_hook(self, func):
self._after_sample_img_hooks.append(func)
def _status_update(self, status: str):
for hook in self._status_update_hooks:
hook(status)
def print_and_status_update(self, status: str):
print_acc(status)
self._status_update(status)
def add_status_update_hook(self, func):
self._status_update_hooks.append(func)
@torch.no_grad()
def generate_images(
@@ -1598,6 +1619,7 @@ class StableDiffusion:
gen_config.save_image(img, i)
gen_config.log_image(img, i)
self._after_sample_image(i, len(image_configs))
flush()
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):

View File

@@ -9,6 +9,7 @@ class Timer:
self.timers = OrderedDict()
self.active_timers = {}
self.current_timer = None # Used for the context manager functionality
self._after_print_hooks = []
def start(self, timer_name):
if timer_name not in self.timers:
@@ -34,12 +35,20 @@ class Timer:
if len(self.timers[timer_name]) > self.max_buffer:
self.timers[timer_name].popleft()
def add_after_print_hook(self, hook):
self._after_print_hooks.append(hook)
def print(self):
print(f"\nTimer '{self.name}':")
timing_dict = {}
# sort by longest at top
for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
avg_time = sum(timings) / len(timings)
print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
timing_dict[timer_name] = avg_time
for hook in self._after_print_hooks:
hook(timing_dict)
print('')

42
ui/.gitignore vendored Normal file
View File

@@ -0,0 +1,42 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.*
.yarn/*
!.yarn/patches
!.yarn/plugins
!.yarn/releases
!.yarn/versions
# testing
/coverage
# next.js
/.next/
/out/
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
# env files (can opt-in for committing if needed)
.env*
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts
aitk_db.db

36
ui/README.md Normal file
View File

@@ -0,0 +1,36 @@
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
## Getting Started
First, run the development server:
```bash
npm run dev
# or
yarn dev
# or
pnpm dev
# or
bun dev
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
## Learn More
To learn more about Next.js, take a look at the following resources:
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
## Deploy on Vercel
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.

15
ui/next.config.ts Normal file
View File

@@ -0,0 +1,15 @@
import type { NextConfig } from 'next';
const nextConfig: NextConfig = {
typescript: {
// Remove this. Build fails because of route types
ignoreBuildErrors: true,
},
experimental: {
serverActions: {
bodySizeLimit: '100mb',
},
},
};
export default nextConfig;

4516
ui/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

40
ui/package.json Normal file
View File

@@ -0,0 +1,40 @@
{
"name": "ai-toolkit-ui",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"build": "next build",
"start": "next start --port 8675",
"lint": "next lint",
"update_db": "npx prisma generate ; npx prisma db push",
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
},
"dependencies": {
"@headlessui/react": "^2.2.0",
"@prisma/client": "^6.3.1",
"axios": "^1.7.9",
"classnames": "^2.5.1",
"lucide-react": "^0.475.0",
"next": "15.1.7",
"node-cache": "^5.1.2",
"prisma": "^6.3.1",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-dropzone": "^14.3.5",
"react-global-hooks": "^1.3.5",
"react-icons": "^5.5.0",
"sqlite3": "^5.1.7"
},
"devDependencies": {
"@types/node": "^20",
"@types/react": "^19",
"@types/react-dom": "^19",
"postcss": "^8",
"prettier": "^3.5.1",
"prettier-basic": "^1.0.0",
"tailwindcss": "^3.4.1",
"typescript": "^5"
},
"prettier": "prettier-basic"
}

8
ui/postcss.config.mjs Normal file
View File

@@ -0,0 +1,8 @@
/** @type {import('postcss-load-config').Config} */
const config = {
plugins: {
tailwindcss: {},
},
};
export default config;

28
ui/prisma/schema.prisma Normal file
View File

@@ -0,0 +1,28 @@
generator client {
provider = "prisma-client-js"
}
datasource db {
provider = "sqlite"
url = "file:../../aitk_db.db"
}
model Settings {
id Int @id @default(autoincrement())
key String @unique
value String
}
model Job {
id String @id @default(uuid())
name String @unique
gpu_ids String
job_config String // JSON string
created_at DateTime @default(now())
updated_at DateTime @updatedAt
status String @default("stopped")
stop Boolean @default(false)
step Int @default(0)
info String @default("")
speed_string String @default("")
}

1
ui/public/file.svg Normal file
View File

@@ -0,0 +1 @@
<svg fill="none" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg"><path d="M14.5 13.5V5.41a1 1 0 0 0-.3-.7L9.8.29A1 1 0 0 0 9.08 0H1.5v13.5A2.5 2.5 0 0 0 4 16h8a2.5 2.5 0 0 0 2.5-2.5m-1.5 0v-7H8v-5H3v12a1 1 0 0 0 1 1h8a1 1 0 0 0 1-1M9.5 5V2.12L12.38 5zM5.13 5h-.62v1.25h2.12V5zm-.62 3h7.12v1.25H4.5zm.62 3h-.62v1.25h7.12V11z" clip-rule="evenodd" fill="#666" fill-rule="evenodd"/></svg>

After

Width:  |  Height:  |  Size: 391 B

1
ui/public/globe.svg Normal file
View File

@@ -0,0 +1 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><g clip-path="url(#a)"><path fill-rule="evenodd" clip-rule="evenodd" d="M10.27 14.1a6.5 6.5 0 0 0 3.67-3.45q-1.24.21-2.7.34-.31 1.83-.97 3.1M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16m.48-1.52a7 7 0 0 1-.96 0H7.5a4 4 0 0 1-.84-1.32q-.38-.89-.63-2.08a40 40 0 0 0 3.92 0q-.25 1.2-.63 2.08a4 4 0 0 1-.84 1.31zm2.94-4.76q1.66-.15 2.95-.43a7 7 0 0 0 0-2.58q-1.3-.27-2.95-.43a18 18 0 0 1 0 3.44m-1.27-3.54a17 17 0 0 1 0 3.64 39 39 0 0 1-4.3 0 17 17 0 0 1 0-3.64 39 39 0 0 1 4.3 0m1.1-1.17q1.45.13 2.69.34a6.5 6.5 0 0 0-3.67-3.44q.65 1.26.98 3.1M8.48 1.5l.01.02q.41.37.84 1.31.38.89.63 2.08a40 40 0 0 0-3.92 0q.25-1.2.63-2.08a4 4 0 0 1 .85-1.32 7 7 0 0 1 .96 0m-2.75.4a6.5 6.5 0 0 0-3.67 3.44 29 29 0 0 1 2.7-.34q.31-1.83.97-3.1M4.58 6.28q-1.66.16-2.95.43a7 7 0 0 0 0 2.58q1.3.27 2.95.43a18 18 0 0 1 0-3.44m.17 4.71q-1.45-.12-2.69-.34a6.5 6.5 0 0 0 3.67 3.44q-.65-1.27-.98-3.1" fill="#666"/></g><defs><clipPath id="a"><path fill="#fff" d="M0 0h16v16H0z"/></clipPath></defs></svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

1
ui/public/next.svg Normal file
View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 394 80"><path fill="#000" d="M262 0h68.5v12.7h-27.2v66.6h-13.6V12.7H262V0ZM149 0v12.7H94v20.4h44.3v12.6H94v21h55v12.6H80.5V0h68.7zm34.3 0h-17.8l63.8 79.4h17.9l-32-39.7 32-39.6h-17.9l-23 28.6-23-28.6zm18.3 56.7-9-11-27.1 33.7h17.8l18.3-22.7z"/><path fill="#000" d="M81 79.3 17 0H0v79.3h13.6V17l50.2 62.3H81Zm252.6-.4c-1 0-1.8-.4-2.5-1s-1.1-1.6-1.1-2.6.3-1.8 1-2.5 1.6-1 2.6-1 1.8.3 2.5 1a3.4 3.4 0 0 1 .6 4.3 3.7 3.7 0 0 1-3 1.8zm23.2-33.5h6v23.3c0 2.1-.4 4-1.3 5.5a9.1 9.1 0 0 1-3.8 3.5c-1.6.8-3.5 1.3-5.7 1.3-2 0-3.7-.4-5.3-1s-2.8-1.8-3.7-3.2c-.9-1.3-1.4-3-1.4-5h6c.1.8.3 1.6.7 2.2s1 1.2 1.6 1.5c.7.4 1.5.5 2.4.5 1 0 1.8-.2 2.4-.6a4 4 0 0 0 1.6-1.8c.3-.8.5-1.8.5-3V45.5zm30.9 9.1a4.4 4.4 0 0 0-2-3.3 7.5 7.5 0 0 0-4.3-1.1c-1.3 0-2.4.2-3.3.5-.9.4-1.6 1-2 1.6a3.5 3.5 0 0 0-.3 4c.3.5.7.9 1.3 1.2l1.8 1 2 .5 3.2.8c1.3.3 2.5.7 3.7 1.2a13 13 0 0 1 3.2 1.8 8.1 8.1 0 0 1 3 6.5c0 2-.5 3.7-1.5 5.1a10 10 0 0 1-4.4 3.5c-1.8.8-4.1 1.2-6.8 1.2-2.6 0-4.9-.4-6.8-1.2-2-.8-3.4-2-4.5-3.5a10 10 0 0 1-1.7-5.6h6a5 5 0 0 0 3.5 4.6c1 .4 2.2.6 3.4.6 1.3 0 2.5-.2 3.5-.6 1-.4 1.8-1 2.4-1.7a4 4 0 0 0 .8-2.4c0-.9-.2-1.6-.7-2.2a11 11 0 0 0-2.1-1.4l-3.2-1-3.8-1c-2.8-.7-5-1.7-6.6-3.2a7.2 7.2 0 0 1-2.4-5.7 8 8 0 0 1 1.7-5 10 10 0 0 1 4.3-3.5c2-.8 4-1.2 6.4-1.2 2.3 0 4.4.4 6.2 1.2 1.8.8 3.2 2 4.3 3.4 1 1.4 1.5 3 1.5 5h-5.8z"/></svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

BIN
ui/public/ostris_logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

1
ui/public/vercel.svg Normal file
View File

@@ -0,0 +1 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1155 1000"><path d="m577.3 0 577.4 1000H0z" fill="#fff"/></svg>

After

Width:  |  Height:  |  Size: 128 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

1
ui/public/window.svg Normal file
View File

@@ -0,0 +1 @@
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path fill-rule="evenodd" clip-rule="evenodd" d="M1.5 2.5h13v10a1 1 0 0 1-1 1h-11a1 1 0 0 1-1-1zM0 1h16v11.5a2.5 2.5 0 0 1-2.5 2.5h-11A2.5 2.5 0 0 1 0 12.5zm3.75 4.5a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5M7 4.75a.75.75 0 1 1-1.5 0 .75.75 0 0 1 1.5 0m1.75.75a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5" fill="#666"/></svg>

After

Width:  |  Height:  |  Size: 385 B

View File

@@ -0,0 +1,42 @@
/* eslint-disable */
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
const { imagePath } = await params;
try {
// Decode the path
const filepath = decodeURIComponent(imagePath);
// caption name is the filepath without extension but with .txt
const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt';
// Get allowed directories
const allowedDir = await getDatasetsRoot();
// Security check: Ensure path is in allowed directory
const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');
if (!isAllowed) {
console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
return new NextResponse('Access denied', { status: 403 });
}
// Check if file exists
if (!fs.existsSync(captionPath)) {
// send back blank string if caption file does not exist
return new NextResponse('');
}
// Read caption file
const caption = fs.readFileSync(captionPath, 'utf-8');
// Return caption
return new NextResponse(caption);
} catch (error) {
console.error('Error getting caption:', error);
return new NextResponse('Error getting caption', { status: 500 });
}
}

View File

@@ -0,0 +1,22 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {
const body = await request.json();
const { name } = body;
let datasetsPath = await getDatasetsRoot();
let datasetPath = path.join(datasetsPath, name);
// if folder doesnt exist, create it
if (!fs.existsSync(datasetPath)) {
fs.mkdirSync(datasetPath);
}
return NextResponse.json({ success: true });
} catch (error) {
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
}
}

View File

@@ -0,0 +1,24 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {
const body = await request.json();
const { name } = body;
let datasetsPath = await getDatasetsRoot();
let datasetPath = path.join(datasetsPath, name);
// if folder doesnt exist, ignore
if (!fs.existsSync(datasetPath)) {
return NextResponse.json({ success: true });
}
// delete it and return success
fs.rmdirSync(datasetPath, { recursive: true });
return NextResponse.json({ success: true });
} catch (error) {
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
}
}

View File

@@ -0,0 +1,25 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';
export async function GET() {
try {
let datasetsPath = await getDatasetsRoot();
// if folder doesnt exist, create it
if (!fs.existsSync(datasetsPath)) {
fs.mkdirSync(datasetsPath);
}
// find all the folders in the datasets folder
let folders = fs
.readdirSync(datasetsPath, { withFileTypes: true })
.filter(dirent => dirent.isDirectory())
.filter(dirent => !dirent.name.startsWith('.'))
.map(dirent => dirent.name);
return NextResponse.json(folders);
} catch (error) {
return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 });
}
}

View File

@@ -0,0 +1,67 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
const datasetsPath = await getDatasetsRoot();
const body = await request.json();
const { datasetName } = body;
const datasetFolder = path.join(datasetsPath, datasetName);
try {
// Check if folder exists
if (!fs.existsSync(datasetFolder)) {
return NextResponse.json(
{ error: `Folder '${datasetName}' not found` },
{ status: 404 }
);
}
// Find all images recursively
const imageFiles = findImagesRecursively(datasetFolder);
// Format response
const result = imageFiles.map(imgPath => ({
img_path: imgPath
}));
return NextResponse.json({ images: result });
} catch (error) {
console.error('Error finding images:', error);
return NextResponse.json(
{ error: 'Failed to process request' },
{ status: 500 }
);
}
}
/**
* Recursively finds all image files in a directory and its subdirectories
* @param dir Directory to search
* @returns Array of absolute paths to image files
*/
function findImagesRecursively(dir: string): string[] {
const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp'];
let results: string[] = [];
const items = fs.readdirSync(dir);
for (const item of items) {
const itemPath = path.join(dir, item);
const stat = fs.statSync(itemPath);
if (stat.isDirectory()) {
// If it's a directory, recursively search it
results = results.concat(findImagesRecursively(itemPath));
} else {
// If it's a file, check if it's an image
const ext = path.extname(itemPath).toLowerCase();
if (imageExtensions.includes(ext)) {
results.push(itemPath);
}
}
}
return results;
}

View File

@@ -0,0 +1,55 @@
// src/app/api/datasets/upload/route.ts
import { NextRequest, NextResponse } from 'next/server';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: NextRequest) {
try {
const datasetsPath = await getDatasetsRoot();
if (!datasetsPath) {
return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 });
}
const formData = await request.formData();
const files = formData.getAll('files');
const datasetName = formData.get('datasetName') as string;
if (!files || files.length === 0) {
return NextResponse.json({ error: 'No files provided' }, { status: 400 });
}
// Create upload directory if it doesn't exist
const uploadDir = join(datasetsPath, datasetName);
await mkdir(uploadDir, { recursive: true });
const savedFiles = await Promise.all(
files.map(async (file: any) => {
const bytes = await file.arrayBuffer();
const buffer = Buffer.from(bytes);
// Clean filename and ensure it's unique
const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_');
const filePath = join(uploadDir, fileName);
await writeFile(filePath, buffer);
return fileName;
}),
);
return NextResponse.json({
message: 'Files uploaded successfully',
files: savedFiles,
});
} catch (error) {
console.error('Upload error:', error);
return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
}
}
// Increase payload size limit (default is 4mb)
export const config = {
api: {
bodyParser: false,
responseLimit: '50mb',
},
};

View File

@@ -0,0 +1,106 @@
/* eslint-disable */
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
const { filePath } = await params;
try {
// Decode the path
const decodedFilePath = decodeURIComponent(filePath);
// Get allowed directories
const datasetRoot = await getDatasetsRoot();
const trainingRoot = await getTrainingFolder();
const allowedDirs = [datasetRoot, trainingRoot];
// Security check: Ensure path is in allowed directory
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
if (!isAllowed) {
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
return new NextResponse('Access denied', { status: 403 });
}
// Check if file exists
if (!fs.existsSync(decodedFilePath)) {
console.warn(`File not found: ${decodedFilePath}`);
return new NextResponse('File not found', { status: 404 });
}
// Get file info
const stat = fs.statSync(decodedFilePath);
if (!stat.isFile()) {
return new NextResponse('Not a file', { status: 400 });
}
// Get filename for Content-Disposition
const filename = path.basename(decodedFilePath);
// Determine content type
const ext = path.extname(decodedFilePath).toLowerCase();
const contentTypeMap: { [key: string]: string } = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp',
'.svg': 'image/svg+xml',
'.bmp': 'image/bmp',
'.safetensors': 'application/octet-stream',
};
const contentType = contentTypeMap[ext] || 'application/octet-stream';
// Get range header for partial content support
const range = request.headers.get('range');
// Common headers for better download handling
const commonHeaders = {
'Content-Type': contentType,
'Accept-Ranges': 'bytes',
'Cache-Control': 'public, max-age=86400',
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
'X-Content-Type-Options': 'nosniff'
};
if (range) {
// Parse range header
const parts = range.replace(/bytes=/, '').split('-');
const start = parseInt(parts[0], 10);
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
const chunkSize = (end - start) + 1;
const fileStream = fs.createReadStream(decodedFilePath, {
start,
end,
highWaterMark: 64 * 1024 // 64KB buffer
});
return new NextResponse(fileStream as any, {
status: 206,
headers: {
...commonHeaders,
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
'Content-Length': String(chunkSize)
},
});
} else {
// For full file download, read directly without streaming wrapper
const fileStream = fs.createReadStream(decodedFilePath, {
highWaterMark: 64 * 1024 // 64KB buffer
});
return new NextResponse(fileStream as any, {
headers: {
...commonHeaders,
'Content-Length': String(stat.size)
},
});
}
} catch (error) {
console.error('Error serving file:', error);
return new NextResponse('Internal Server Error', { status: 500 });
}
}

106
ui/src/app/api/gpu/route.ts Normal file
View File

@@ -0,0 +1,106 @@
import { NextResponse } from 'next/server';
import { exec } from 'child_process';
import { promisify } from 'util';
const execAsync = promisify(exec);
export async function GET() {
try {
// Check if nvidia-smi is available
const hasNvidiaSmi = await checkNvidiaSmi();
if (!hasNvidiaSmi) {
return NextResponse.json({
hasNvidiaSmi: false,
gpus: [],
error: 'nvidia-smi not found or not accessible',
});
}
// Get GPU stats
const gpuStats = await getGpuStats();
return NextResponse.json({
hasNvidiaSmi: true,
gpus: gpuStats,
});
} catch (error) {
console.error('Error fetching NVIDIA GPU stats:', error);
return NextResponse.json(
{
hasNvidiaSmi: false,
gpus: [],
error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
},
{ status: 500 },
);
}
}
async function checkNvidiaSmi(): Promise<boolean> {
try {
await execAsync('which nvidia-smi');
return true;
} catch (error) {
return false;
}
}
async function getGpuStats() {
// Get detailed GPU information in JSON format including fan speed
const { stdout } = await execAsync(
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits',
);
// Parse CSV output
const gpus = stdout
.trim()
.split('\n')
.map(line => {
const [
index,
name,
driverVersion,
temperature,
gpuUtil,
memoryUtil,
memoryTotal,
memoryFree,
memoryUsed,
powerDraw,
powerLimit,
clockGraphics,
clockMemory,
fanSpeed,
] = line.split(', ').map(item => item.trim());
return {
index: parseInt(index),
name,
driverVersion,
temperature: parseInt(temperature),
utilization: {
gpu: parseInt(gpuUtil),
memory: parseInt(memoryUtil),
},
memory: {
total: parseInt(memoryTotal),
free: parseInt(memoryFree),
used: parseInt(memoryUsed),
},
power: {
draw: parseFloat(powerDraw),
limit: parseFloat(powerLimit),
},
clocks: {
graphics: parseInt(clockGraphics),
memory: parseInt(clockMemory),
},
fan: {
speed: parseInt(fanSpeed), // Fan speed as percentage
},
};
});
return gpus;
}

View File

@@ -0,0 +1,68 @@
/* eslint-disable */
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
const { imagePath } = await params;
try {
// Decode the path
const filepath = decodeURIComponent(imagePath);
// Get allowed directories
const datasetRoot = await getDatasetsRoot();
const trainingRoot = await getTrainingFolder();
const allowedDirs = [datasetRoot, trainingRoot];
// Security check: Ensure path is in allowed directory
const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
if (!isAllowed) {
console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
return new NextResponse('Access denied', { status: 403 });
}
// Check if file exists
if (!fs.existsSync(filepath)) {
console.warn(`File not found: ${filepath}`);
return new NextResponse('File not found', { status: 404 });
}
// Get file info
const stat = fs.statSync(filepath);
if (!stat.isFile()) {
return new NextResponse('Not a file', { status: 400 });
}
// Determine content type
const ext = path.extname(filepath).toLowerCase();
const contentTypeMap: { [key: string]: string } = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp',
'.svg': 'image/svg+xml',
'.bmp': 'image/bmp',
};
const contentType = contentTypeMap[ext] || 'application/octet-stream';
// Read file as buffer
const fileBuffer = fs.readFileSync(filepath);
// Return file with appropriate headers
return new NextResponse(fileBuffer, {
headers: {
'Content-Type': contentType,
'Content-Length': String(stat.size),
'Cache-Control': 'public, max-age=86400',
},
});
} catch (error) {
console.error('Error serving image:', error);
return new NextResponse('Internal Server Error', { status: 500 });
}
}

View File

@@ -0,0 +1,30 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {
const body = await request.json();
const { imgPath, caption } = body;
let datasetsPath = await getDatasetsRoot();
// make sure the dataset path is in the image path
if (!imgPath.startsWith(datasetsPath)) {
return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
}
// if img doesnt exist, ignore
if (!fs.existsSync(imgPath)) {
return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
}
// check for caption
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
// save caption to file
fs.writeFileSync(captionPath, caption);
return NextResponse.json({ success: true });
} catch (error) {
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
}
}

View File

@@ -0,0 +1,34 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {
const body = await request.json();
const { imgPath } = body;
let datasetsPath = await getDatasetsRoot();
// make sure the dataset path is in the image path
if (!imgPath.startsWith(datasetsPath)) {
return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
}
// if img doesnt exist, ignore
if (!fs.existsSync(imgPath)) {
return NextResponse.json({ success: true });
}
// delete it and return success
fs.unlinkSync(imgPath);
// check for caption
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
if (fs.existsSync(captionPath)) {
// delete caption file
fs.unlinkSync(captionPath);
}
return NextResponse.json({ success: true });
} catch (error) {
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
}
}

View File

@@ -0,0 +1,32 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { getTrainingFolder } from '@/server/settings';
import path from 'path';
import fs from 'fs';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
if (!job) {
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
}
const trainingRoot = await getTrainingFolder();
const trainingFolder = path.join(trainingRoot, job.name);
if (fs.existsSync(trainingFolder)) {
fs.rmdirSync(trainingFolder, { recursive: true });
}
await prisma.job.delete({
where: { id: jobID },
});
return NextResponse.json(job);
}

View File

@@ -0,0 +1,48 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
if (!job) {
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
}
const trainingFolder = await getTrainingFolder();
const jobFolder = path.join(trainingFolder, job.name);
if (!fs.existsSync(jobFolder)) {
return NextResponse.json({ files: [] });
}
// find all img (png, jpg, jpeg) files in the samples folder
let files = fs
.readdirSync(jobFolder)
.filter(file => {
return file.endsWith('.safetensors');
})
.map(file => {
return path.join(jobFolder, file);
})
.sort();
// get the file size for each file
const fileObjects = files.map(file => {
const stats = fs.statSync(file);
return {
path: file,
size: stats.size,
};
});
return NextResponse.json({ files: fileObjects });
}

View File

@@ -0,0 +1,40 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
if (!job) {
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
}
// setup the training
const trainingFolder = await getTrainingFolder();
const samplesFolder = path.join(trainingFolder, job.name, 'samples');
if (!fs.existsSync(samplesFolder)) {
return NextResponse.json({ samples: [] });
}
// find all img (png, jpg, jpeg) files in the samples folder
const samples = fs
.readdirSync(samplesFolder)
.filter(file => {
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg');
})
.map(file => {
return path.join(samplesFolder, file);
})
.sort();
return NextResponse.json({ samples });
}

View File

@@ -0,0 +1,93 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
import { spawn } from 'child_process';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder, getHFToken } from '@/server/settings';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
if (!job) {
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
}
// update job status to 'running'
await prisma.job.update({
where: { id: jobID },
data: {
status: 'running',
stop: false,
info: 'Starting job...',
},
});
// setup the training
const trainingRoot = await getTrainingFolder();
const trainingFolder = path.join(trainingRoot, job.name);
if (!fs.existsSync(trainingFolder)) {
fs.mkdirSync(trainingFolder, { recursive: true });
}
// make the config file
const configPath = path.join(trainingFolder, '.job_config.json');
// update the config dataset path
const jobConfig = JSON.parse(job.job_config);
jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
// write the config file
fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
let pythonPath = 'python';
// use .venv or venv if it exists
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
}
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
if (!fs.existsSync(runFilePath)) {
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
}
const additionalEnv: any = {
AITK_JOB_ID: jobID,
CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
};
// HF_TOKEN
const hfToken = await getHFToken();
if (hfToken && hfToken.trim() !== '') {
additionalEnv.HF_TOKEN = hfToken;
}
// console.log(
// 'Spawning command:',
// `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
// );
// start job
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
detached: true,
stdio: 'ignore',
env: {
...process.env,
...additionalEnv,
},
cwd: TOOLKIT_ROOT,
});
subprocess.unref();
return NextResponse.json(job);
}

View File

@@ -0,0 +1,23 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
// update job status to 'running'
await prisma.job.update({
where: { id: jobID },
data: {
stop: true,
info: 'Stopping job...',
},
});
return NextResponse.json(job);
}

View File

@@ -0,0 +1,58 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();
export async function GET(request: Request) {
const { searchParams } = new URL(request.url);
const id = searchParams.get('id');
try {
if (id) {
const job = await prisma.job.findUnique({
where: { id },
});
return NextResponse.json(job);
}
const jobs = await prisma.job.findMany({
orderBy: { created_at: 'desc' },
});
return NextResponse.json({ jobs: jobs });
} catch (error) {
console.error(error);
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
}
}
export async function POST(request: Request) {
try {
const body = await request.json();
const { id, name, job_config, gpu_ids } = body;
if (id) {
// Update existing training
const training = await prisma.job.update({
where: { id },
data: {
name,
gpu_ids,
job_config: JSON.stringify(job_config),
},
});
return NextResponse.json(training);
} else {
// Create new training
const training = await prisma.job.create({
data: {
name,
gpu_ids,
job_config: JSON.stringify(job_config),
},
});
return NextResponse.json(training);
}
} catch (error) {
return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 });
}
}

View File

@@ -0,0 +1,59 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
import {flushCache} from '@/server/settings';
const prisma = new PrismaClient();
export async function GET() {
try {
const settings = await prisma.settings.findMany();
const settingsObject = settings.reduce((acc: any, setting) => {
acc[setting.key] = setting.value;
return acc;
}, {});
// if TRAINING_FOLDER is not set, use default
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
}
// if DATASETS_FOLDER is not set, use default
if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') {
settingsObject.DATASETS_FOLDER = defaultDatasetsFolder;
}
return NextResponse.json(settingsObject);
} catch (error) {
return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
}
}
export async function POST(request: Request) {
try {
const body = await request.json();
const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body;
// Upsert both settings
await Promise.all([
prisma.settings.upsert({
where: { key: 'HF_TOKEN' },
update: { value: HF_TOKEN },
create: { key: 'HF_TOKEN', value: HF_TOKEN },
}),
prisma.settings.upsert({
where: { key: 'TRAINING_FOLDER' },
update: { value: TRAINING_FOLDER },
create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER },
}),
prisma.settings.upsert({
where: { key: 'DATASETS_FOLDER' },
update: { value: DATASETS_FOLDER },
create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER },
}),
]);
flushCache();
return NextResponse.json({ success: true });
} catch (error) {
return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
}
}

BIN
ui/src/app/apple-icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.7 KiB

View File

@@ -0,0 +1,31 @@
'use client';
import GpuMonitor from '@/components/GPUMonitor';
import JobsTable from '@/components/JobsTable';
import { TopBar, MainContent } from '@/components/layout';
import Link from 'next/link';
export default function Dashboard() {
return (
<>
<TopBar>
<div>
<h1 className="text-lg">Dashboard</h1>
</div>
<div className="flex-1"></div>
</TopBar>
<MainContent>
<GpuMonitor />
<div className="w-full mt-4">
<div className="flex justify-between items-center mb-2">
<h1 className="text-md">Active Jobs</h1>
<div className="text-xs text-gray-500">
<Link href="/jobs">View All</Link>
</div>
</div>
<JobsTable onlyActive />
</div>
</MainContent>
</>
);
}

View File

@@ -0,0 +1,86 @@
'use client';
import { useEffect, useState, use } from 'react';
import { FaChevronLeft } from 'react-icons/fa';
import DatasetImageCard from '@/components/DatasetImageCard';
import { Button } from '@headlessui/react';
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
import { TopBar, MainContent } from '@/components/layout';
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
const usableParams = use(params as any) as { datasetName: string };
const datasetName = usableParams.datasetName;
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshImageList = (dbName: string) => {
setStatus('loading');
fetch('/api/datasets/listImages', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ datasetName: dbName }),
})
.then(res => res.json())
.then(data => {
console.log('Images:', data.images);
// sort
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
setImgList(data.images);
setStatus('success');
})
.catch(error => {
console.error('Error fetching images:', error);
setStatus('error');
});
};
useEffect(() => {
if (datasetName) {
refreshImageList(datasetName);
}
}, [datasetName]);
return (
<>
{/* Fixed top bar */}
<TopBar>
<div>
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
<FaChevronLeft />
</Button>
</div>
<div>
<h1 className="text-lg">Dataset: {datasetName}</h1>
</div>
<div className="flex-1"></div>
<div>
<Button
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
onClick={() => openImagesModal(datasetName, () => refreshImageList(datasetName))}
>
Add Images
</Button>
</div>
</TopBar>
<MainContent>
{status === 'loading' && <p>Loading...</p>}
{status === 'error' && <p>Error fetching images</p>}
{status === 'success' && (
<div className="grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-4">
{imgList.length === 0 && <p>No images found</p>}
{imgList.map(img => (
<DatasetImageCard
key={img.img_path}
alt="image"
imageUrl={img.img_path}
onDelete={() => refreshImageList(datasetName)}
/>
))}
</div>
)}
</MainContent>
<AddImagesModal />
</>
);
}

View File

@@ -0,0 +1,157 @@
'use client';
import { useState } from 'react';
import { Modal } from '@/components/Modal';
import Link from 'next/link';
import { TextInput } from '@/components/formInputs';
import useDatasetList from '@/hooks/useDatasetList';
import { Button } from '@headlessui/react';
import { FaRegTrashAlt } from 'react-icons/fa';
import { openConfirm } from '@/components/ConfirmModal';
import { TopBar, MainContent } from '@/components/layout';
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
export default function Datasets() {
const { datasets, status, refreshDatasets } = useDatasetList();
const [newDatasetName, setNewDatasetName] = useState('');
const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
// Transform datasets array into rows with objects
const tableRows = datasets.map(dataset => ({
name: dataset,
actions: dataset, // Pass full dataset name for actions
}));
const columns: TableColumn[] = [
{
title: 'Dataset Name',
key: 'name',
render: row => (
<Link href={`/datasets/${row.name}`} className="text-gray-200 hover:text-gray-100">
{row.name}
</Link>
),
},
{
title: 'Actions',
key: 'actions',
className: 'w-20 text-right',
render: row => (
<button
className="text-gray-200 hover:bg-red-600 p-2 rounded-full transition-colors"
onClick={() => handleDeleteDataset(row.name)}
>
<FaRegTrashAlt />
</button>
),
},
];
const handleDeleteDataset = (datasetName: string) => {
openConfirm({
title: 'Delete Dataset',
message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`,
type: 'warning',
confirmText: 'Delete',
onConfirm: () => {
fetch('/api/datasets/delete', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ name: datasetName }),
})
.then(res => res.json())
.then(data => {
console.log('Dataset deleted:', data);
refreshDatasets();
})
.catch(error => {
console.error('Error deleting dataset:', error);
});
},
});
};
const handleCreateDataset = async (e: React.FormEvent) => {
e.preventDefault();
try {
const response = await fetch('/api/datasets/create', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ name: newDatasetName }),
});
const data = await response.json();
console.log('New dataset created:', data);
refreshDatasets();
setNewDatasetName('');
setIsNewDatasetModalOpen(false);
} catch (error) {
console.error('Error creating new dataset:', error);
}
};
return (
<>
<TopBar>
<div>
<h1 className="text-2xl font-semibold text-gray-100">Datasets</h1>
</div>
<div className="flex-1"></div>
<div>
<Button
className="text-gray-200 bg-slate-600 px-4 py-2 rounded-md hover:bg-slate-500 transition-colors"
onClick={() => setIsNewDatasetModalOpen(true)}
>
New Dataset
</Button>
</div>
</TopBar>
<MainContent>
<UniversalTable
columns={columns}
rows={tableRows}
isLoading={status === 'loading'}
onRefresh={refreshDatasets}
/>
</MainContent>
<Modal
isOpen={isNewDatasetModalOpen}
onClose={() => setIsNewDatasetModalOpen(false)}
title="New Dataset"
size="md"
>
<div className="space-y-4 text-gray-200">
<form onSubmit={handleCreateDataset}>
<div className="text-sm text-gray-400">
This will create a new folder with the name below in your dataset folder.
</div>
<div className="mt-4">
<TextInput label="Dataset Name" value={newDatasetName} onChange={value => setNewDatasetName(value)} />
</div>
<div className="mt-6 flex justify-end space-x-3">
<button
type="button"
className="rounded-md bg-gray-700 px-4 py-2 text-gray-200 hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-gray-500"
onClick={() => setIsNewDatasetModalOpen(false)}
>
Cancel
</button>
<button
type="submit"
className="rounded-md bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500"
>
Confirm
</button>
</div>
</form>
</div>
</Modal>
</>
);
}

BIN
ui/src/app/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

21
ui/src/app/globals.css Normal file
View File

@@ -0,0 +1,21 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
:root {
--background: #ffffff;
--foreground: #171717;
}
@media (prefers-color-scheme: dark) {
:root {
--background: #0a0a0a;
--foreground: #ededed;
}
}
body {
color: var(--foreground);
background: var(--background);
font-family: Arial, Helvetica, sans-serif;
}

BIN
ui/src/app/icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.3 KiB

3
ui/src/app/icon.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 110 KiB

View File

@@ -0,0 +1,80 @@
'use client';
import { useMemo, useState, use } from 'react';
import { FaChevronLeft } from 'react-icons/fa';
import { Button } from '@headlessui/react';
import { TopBar, MainContent } from '@/components/layout';
import useJob from '@/hooks/useJob';
import { startJob, stopJob } from '@/utils/jobs';
import SampleImages from '@/components/SampleImages';
import JobOverview from '@/components/JobOverview';
import { JobConfig } from '@/types';
import { redirect } from 'next/navigation';
import JobActionBar from '@/components/JobActionBar';
type PageKey = 'overview' | 'samples';
interface Page {
name: string;
value: PageKey;
}
const pages: Page[] = [
{ name: 'Overview', value: 'overview' },
{ name: 'Samples', value: 'samples' },
];
export default function JobPage({ params }: { params: { jobID: string } }) {
const usableParams = use(params as any) as { jobID: string };
const jobID = usableParams.jobID;
const { job, status, refreshJob } = useJob(jobID, 5000);
const [pageKey, setPageKey] = useState<PageKey>('overview');
return (
<>
{/* Fixed top bar */}
<TopBar>
<div>
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => redirect('/jobs')}>
<FaChevronLeft />
</Button>
</div>
<div>
<h1 className="text-lg">Job: {job?.name}</h1>
</div>
<div className="flex-1"></div>
{job && (
<JobActionBar
job={job}
onRefresh={refreshJob}
hideView
afterDelete={() => {
redirect('/jobs');
}}
/>
)}
</TopBar>
<MainContent className="pt-24">
{status === 'loading' && job == null && <p>Loading...</p>}
{status === 'error' && job == null && <p>Error fetching job</p>}
{job && (
<>
{pageKey === 'overview' && <JobOverview job={job} />}
{pageKey === 'samples' && <SampleImages job={job} />}
</>
)}
</MainContent>
<div className="bg-gray-800 absolute top-12 left-0 w-full h-8 flex items-center px-2 text-sm">
{pages.map(page => (
<Button
key={page.value}
onClick={() => setPageKey(page.value)}
className={`px-4 py-1 h-8 ${page.value === pageKey ? 'bg-gray-300 dark:bg-gray-700' : ''}`}
>
{page.name}
</Button>
))}
</div>
</>
);
}

View File

@@ -0,0 +1,101 @@
import { JobConfig, DatasetConfig } from '@/types';
export const defaultDatasetConfig: DatasetConfig = {
folder_path: '/path/to/images/folder',
mask_path: null,
mask_min_value: 0.1,
default_caption: '',
caption_ext: 'txt',
caption_dropout_rate: 0.05,
cache_latents_to_disk: false,
is_reg: false,
network_weight: 1,
resolution: [512, 768, 1024],
};
export const defaultJobConfig: JobConfig = {
job: 'extension',
config: {
name: 'my_first_flex_lora_v1',
process: [
{
type: 'ui_trainer',
training_folder: 'output',
sqlite_db_path: './aitk_db.db',
device: 'cuda:0',
trigger_word: null,
performance_log_every: 10,
network: {
type: 'lora',
linear: 16,
linear_alpha: 16,
},
save: {
dtype: 'bf16',
save_every: 250,
max_step_saves_to_keep: 4,
save_format: 'diffusers',
push_to_hub: false,
},
datasets: [
defaultDatasetConfig
],
train: {
batch_size: 1,
bypass_guidance_embedding: true,
steps: 2000,
gradient_accumulation: 1,
train_unet: true,
train_text_encoder: false,
gradient_checkpointing: true,
noise_scheduler: 'flowmatch',
optimizer: 'adamw8bit',
timestep_type: 'sigmoid',
content_or_style: 'balanced',
optimizer_params: {
weight_decay: 1e-4
},
lr: 0.0001,
ema_config: {
use_ema: true,
ema_decay: 0.99,
},
dtype: 'bf16',
},
model: {
name_or_path: 'ostris/Flex.1-alpha',
is_flux: true,
quantize: true,
quantize_te: true
},
sample: {
sampler: 'flowmatch',
sample_every: 250,
width: 1024,
height: 1024,
prompts: [
'woman with red hair, playing chess at the park, bomb going off in the background',
'a woman holding a coffee cup, in a beanie, sitting at a cafe',
'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
'a bear building a log cabin in the snow covered mountains',
'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
'hipster man with a beard, building a chair, in a wood shop',
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
"a man holding a sign that says, 'this is a sign'",
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
],
neg: '',
seed: 42,
walk_seed: true,
guidance_scale: 4,
sample_steps: 25,
},
},
],
},
meta: {
name: '[name]',
version: '1.0',
},
};

View File

@@ -0,0 +1,41 @@
export interface Model {
name_or_path: string;
defaults?: { [key: string]: any };
}
export interface Option {
model: Model[];
}
export const options = {
model: [
{
name_or_path: 'ostris/Flex.1-alpha',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
'config.process[0].train.bypass_guidance_embedding': [true, false],
},
},
{
name_or_path: 'black-forest-labs/FLUX.1-dev',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
},
},
{
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_lumina2': [true, false],
},
},
],
} as Option;

View File

@@ -0,0 +1,618 @@
'use client';
import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { options } from './options';
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
import Card from '@/components/Card';
import { X } from 'lucide-react';
import useSettings from '@/hooks/useSettings';
import useGPUInfo from '@/hooks/useGPUInfo';
import useDatasetList from '@/hooks/useDatasetList';
import path from 'path';
import { TopBar, MainContent } from '@/components/layout';
import { Button } from '@headlessui/react';
import { FaChevronLeft } from 'react-icons/fa';
export default function TrainingForm() {
const router = useRouter();
const searchParams = useSearchParams();
const runId = searchParams.get('id');
const [gpuIDs, setGpuIDs] = useState<string | null>(null);
const { settings, isSettingsLoaded } = useSettings();
const { gpuList, isGPUInfoLoaded } = useGPUInfo();
const { datasets, status: datasetFetchStatus } = useDatasetList();
const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
useEffect(() => {
if (!isSettingsLoaded) return;
if (datasetFetchStatus !== 'success') return;
const datasetOptions = datasets.map(name => ({ value: path.join(settings.DATASETS_FOLDER, name), label: name }));
setDatasetOptions(datasetOptions);
const defaultDatasetPath = defaultDatasetConfig.folder_path;
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
const dataset = jobConfig.config.process[0].datasets[i];
if (dataset.folder_path === defaultDatasetPath) {
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
}
}
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
useEffect(() => {
if (runId) {
fetch(`/api/jobs?id=${runId}`)
.then(res => res.json())
.then(data => {
setGpuIDs(data.gpu_ids);
setJobConfig(JSON.parse(data.job_config));
// setJobConfig(data.name, 'config.name');
})
.catch(error => console.error('Error fetching training:', error));
}
}, [runId]);
useEffect(() => {
if (isGPUInfoLoaded) {
if (gpuIDs === null && gpuList.length > 0) {
setGpuIDs(`${gpuList[0].index}`);
}
}
}, [gpuList, isGPUInfoLoaded]);
useEffect(() => {
if (isSettingsLoaded) {
setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder');
}
}, [settings, isSettingsLoaded]);
const saveJob = async () => {
if (status === 'saving') return;
setStatus('saving');
try {
const response = await fetch('/api/jobs', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
id: runId,
name: jobConfig.config.name,
gpu_ids: gpuIDs,
job_config: jobConfig,
}),
});
if (!response.ok) throw new Error('Failed to save training');
setStatus('success');
if (!runId) {
const data = await response.json();
router.push(`/jobs/${data.id}`);
}
setTimeout(() => setStatus('idle'), 2000);
} catch (error) {
console.error('Error saving training:', error);
setStatus('error');
setTimeout(() => setStatus('idle'), 2000);
}
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
saveJob();
};
console.log('jobConfig.config.process[0].network.linear', jobConfig?.config?.process[0].network?.linear);
return (
<>
<TopBar>
<div>
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
<FaChevronLeft />
</Button>
</div>
<div>
<h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
</div>
<div className="flex-1"></div>
<div>
<Button
className="text-gray-200 bg-green-800 px-3 py-1 rounded-md"
onClick={() => saveJob()}
disabled={status === 'saving'}
>
{status === 'saving' ? 'Saving...' : runId ? 'Update Job' : 'Create Job'}
</Button>
</div>
</TopBar>
<MainContent>
<form onSubmit={handleSubmit} className="space-y-8">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<Card title="Job Settings">
<TextInput
label="Training Name"
value={jobConfig.config.name}
onChange={value => setJobConfig(value, 'config.name')}
placeholder="Enter training name"
disabled={runId !== null}
required
/>
<SelectInput
label="GPU ID"
value={`${gpuIDs}`}
onChange={value => setGpuIDs(value)}
options={gpuList.map(gpu => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
<TextInput
label="Trigger Word"
value={jobConfig.config.process[0].trigger_word || ''}
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'jobConfig.config.process[0].trigger_word');
}}
placeholder=""
required
/>
</Card>
{/* Model Configuration Section */}
<Card title="Model Configuration">
<SelectInput
label="Name or Path"
value={jobConfig.config.process[0].model.name_or_path}
onChange={value => {
// see if model changed
const currentModel = options.model.find(
model => model.name_or_path === jobConfig.config.process[0].model.name_or_path,
);
if (!currentModel || currentModel.name_or_path === value) {
// model has not changed
return;
}
// revert defaults from previous model
for (const key in currentModel.defaults) {
setJobConfig(currentModel.defaults[key][1], key);
}
// set new model
setJobConfig(value, 'config.process[0].model.name_or_path');
// update the defaults when a model is selected
const model = options.model.find(model => model.name_or_path === value);
if (model?.defaults) {
for (const key in model.defaults) {
setJobConfig(model.defaults[key][0], key);
}
}
}}
options={options.model.map(model => ({
value: model.name_or_path,
label: model.name_or_path,
}))}
/>
<FormGroup label="Quantize">
<div className='grid grid-cols-2 gap-2'>
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
</Card>
{jobConfig.config.process[0].network?.type && (
<Card title="LoRA Configuration">
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
onChange={value => {
console.log('onChange', value);
setJobConfig(value, 'config.process[0].network.linear');
setJobConfig(value, 'config.process[0].network.linear_alpha');
}}
placeholder="eg. 16"
min={0}
max={1024}
required
/>
</Card>
)}
<Card title="Save Configuration">
<SelectInput
label="Data Type"
value={jobConfig.config.process[0].save.dtype}
onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
options={[
{ value: 'bf16', label: 'BF16' },
{ value: 'fp16', label: 'FP16' },
{ value: 'fp32', label: 'FP32' },
]}
/>
<NumberInput
label="Save Every"
value={jobConfig.config.process[0].save.save_every}
onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
placeholder="eg. 250"
min={1}
required
/>
<NumberInput
label="Max Step Saves to Keep"
value={jobConfig.config.process[0].save.max_step_saves_to_keep}
onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
placeholder="eg. 4"
min={1}
required
/>
</Card>
</div>
<div>
<Card title="Training Configuration">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div>
<NumberInput
label="Batch Size"
value={jobConfig.config.process[0].train.batch_size}
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
placeholder="eg. 4"
min={1}
required
/>
<NumberInput
label="Gradient Accumulation"
className="pt-2"
value={jobConfig.config.process[0].train.gradient_accumulation}
onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
placeholder="eg. 1"
min={1}
required
/>
<NumberInput
label="Steps"
className="pt-2"
value={jobConfig.config.process[0].train.steps}
onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
placeholder="eg. 2000"
min={1}
required
/>
</div>
<div>
<SelectInput
label="Optimizer"
value={jobConfig.config.process[0].train.optimizer}
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
options={[
{ value: 'adamw8bit', label: 'AdamW8Bit' },
{ value: 'adafactor', label: 'Adafactor' },
]}
/>
<NumberInput
label="Learning Rate"
className="pt-2"
value={jobConfig.config.process[0].train.lr}
onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
placeholder="eg. 0.0001"
min={0}
required
/>
<NumberInput
label="Weight Decay"
className="pt-2"
value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
placeholder="eg. 0.0001"
min={0}
required
/>
</div>
<div>
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
{ value: 'linear', label: 'Linear' },
{ value: 'flux_shift', label: 'Flux Shift' },
]}
/>
<SelectInput
label="Timestep Bias"
className="pt-2"
value={jobConfig.config.process[0].train.content_or_style}
onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
options={[
{ value: 'balanced', label: 'Balanced' },
{ value: 'content', label: 'High Noise' },
{ value: 'style', label: 'Low Noise' },
]}
/>
<SelectInput
label="Noise Scheduler"
className="pt-2"
value={jobConfig.config.process[0].train.noise_scheduler}
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
/>
</div>
<div>
<FormGroup label="EMA (Exponential Moving Average)">
<Checkbox
label="Use EMA"
className='pt-1'
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
/>
</FormGroup>
<NumberInput
label="EMA Decay"
className="pt-2"
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
placeholder="eg. 0.99"
min={0}
/>
</div>
</div>
</Card>
</div>
<div>
<Card title="Datasets">
<>
{jobConfig.config.process[0].datasets.map((dataset, i) => (
<div key={i} className="p-4 rounded-lg bg-gray-800 relative">
<button
type="button"
onClick={() =>
setJobConfig(
jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
'config.process[0].datasets',
)
}
className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
>
<X />
</button>
<h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div>
<SelectInput
label="Dataset"
value={dataset.folder_path}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
options={datasetOptions}
/>
<NumberInput
label="LoRA Weight"
value={dataset.network_weight}
className="pt-2"
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
placeholder="eg. 1.0"
/>
</div>
<div>
<TextInput
label="Default Caption"
value={dataset.default_caption}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
placeholder="eg. A photo of a cat"
/>
<NumberInput
label="Caption Dropout Rate"
className="pt-2"
value={dataset.caption_dropout_rate}
onChange={value =>
setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)
}
placeholder="eg. 0.05"
min={0}
required
/>
</div>
<div>
<FormGroup label="Settings" className="">
<Checkbox
label="Cache Latents"
checked={dataset.cache_latents_to_disk || false}
onChange={value =>
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
}
/>
<Checkbox
label="Is Regularization"
checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/>
</FormGroup>
</div>
<div>
<FormGroup label="Resolutions" className="pt-2">
<div className="grid grid-cols-2 gap-2">
{[
[256, 512, 768],
[1024, 1280, 1536],
].map(resGroup => (
<div key={resGroup[0]} className="space-y-2">
{resGroup.map(res => (
<Checkbox
key={res}
label={res.toString()}
checked={dataset.resolution.includes(res)}
onChange={value => {
const resolutions = dataset.resolution.includes(res)
? dataset.resolution.filter(r => r !== res)
: [...dataset.resolution, res];
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
}}
/>
))}
</div>
))}
</div>
</FormGroup>
</div>
</div>
</div>
))}
<button
type="button"
onClick={() =>
setJobConfig(
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
'config.process[0].datasets',
)
}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Dataset
</button>
</>
</Card>
</div>
<div>
<Card title="Sample Configuration">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div>
<NumberInput
label="Sample Every"
value={jobConfig.config.process[0].sample.sample_every}
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
placeholder="eg. 250"
min={1}
required
/>
<SelectInput
label="Sampler"
className="pt-2"
value={jobConfig.config.process[0].sample.sampler}
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
/>
</div>
<div>
<NumberInput
label="Guidance Scale"
value={jobConfig.config.process[0].sample.guidance_scale}
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
placeholder="eg. 1.0"
min={0}
required
/>
<NumberInput
label="Sample Steps"
value={jobConfig.config.process[0].sample.sample_steps}
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
placeholder="eg. 1"
className="pt-2"
min={1}
required
/>
</div>
<div>
<NumberInput
label="Width"
value={jobConfig.config.process[0].sample.width}
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
placeholder="eg. 1024"
min={256}
required
/>
<NumberInput
label="Height"
value={jobConfig.config.process[0].sample.height}
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
placeholder="eg. 1024"
className="pt-2"
min={256}
required
/>
</div>
<div>
<NumberInput
label="Seed"
value={jobConfig.config.process[0].sample.seed}
onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
placeholder="eg. 0"
min={0}
required
/>
<Checkbox
label="Walk Seed"
className="pt-4 pl-2"
checked={jobConfig.config.process[0].sample.walk_seed}
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
/>
</div>
</div>
<FormGroup
label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`}
className="pt-2"
>
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
<div key={i} className="flex items-center space-x-2">
<div className="flex-1">
<TextInput
value={prompt}
onChange={value => setJobConfig(value, `config.process[0].sample.prompts[${i}]`)}
placeholder="Enter prompt"
required
/>
</div>
<div>
<button
type="button"
onClick={() =>
setJobConfig(
jobConfig.config.process[0].sample.prompts.filter((_, index) => index !== i),
'config.process[0].sample.prompts',
)
}
className="rounded-full p-1 text-sm"
>
<X />
</button>
</div>
</div>
))}
<button
type="button"
onClick={() =>
setJobConfig(
[...jobConfig.config.process[0].sample.prompts, ''],
'config.process[0].sample.prompts',
)
}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Prompt
</button>
</FormGroup>
</Card>
</div>
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
</form>
<div className="pt-20"></div>
</MainContent>
</>
);
}

29
ui/src/app/jobs/page.tsx Normal file
View File

@@ -0,0 +1,29 @@
'use client';
import JobsTable from '@/components/JobsTable';
import { TopBar, MainContent } from '@/components/layout';
import Link from 'next/link';
export default function Dashboard() {
return (
<>
<TopBar>
<div>
<h1 className="text-lg">Training Jobs</h1>
</div>
<div className="flex-1"></div>
<div>
<Link
href="/jobs/new"
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
>
New Training Job
</Link>
</div>
</TopBar>
<MainContent>
<JobsTable />
</MainContent>
</>
);
}

38
ui/src/app/layout.tsx Normal file
View File

@@ -0,0 +1,38 @@
import type { Metadata } from 'next';
import { Inter } from 'next/font/google';
import './globals.css';
import Sidebar from '@/components/Sidebar';
import { ThemeProvider } from '@/components/ThemeProvider';
import ConfirmModal from '@/components/ConfirmModal';
import SampleImageModal from '@/components/SampleImageModal';
import { Suspense } from 'react';
const inter = Inter({ subsets: ['latin'] });
export const metadata: Metadata = {
title: 'Ostris - AI Toolkit',
description: 'A toolkit for building AI things.',
};
export default function RootLayout({ children }: { children: React.ReactNode }) {
return (
<html lang="en" className="dark">
<head>
<meta name="apple-mobile-web-app-title" content="AI-Toolkit" />
</head>
<body className={inter.className}>
<ThemeProvider>
<div className="flex h-screen bg-gray-950">
<Sidebar />
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
<Suspense>{children}</Suspense>
</main>
</div>
</ThemeProvider>
<ConfirmModal />
<SampleImageModal />
</body>
</html>
);
}

21
ui/src/app/manifest.json Normal file
View File

@@ -0,0 +1,21 @@
{
"name": "AI Toolkit",
"short_name": "AIToolkit",
"icons": [
{
"src": "/web-app-manifest-192x192.png",
"sizes": "192x192",
"type": "image/png",
"purpose": "maskable"
},
{
"src": "/web-app-manifest-512x512.png",
"sizes": "512x512",
"type": "image/png",
"purpose": "maskable"
}
],
"theme_color": "#000000",
"background_color": "#000000",
"display": "standalone"
}

5
ui/src/app/page.tsx Normal file
View File

@@ -0,0 +1,5 @@
import { redirect } from 'next/navigation';
export default function Home() {
redirect('/dashboard');
}

View File

@@ -0,0 +1,134 @@
'use client';
import { useEffect, useState } from 'react';
import useSettings from '@/hooks/useSettings';
import { TopBar, MainContent } from '@/components/layout';
export default function Settings() {
const { settings, setSettings } = useSettings();
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setStatus('saving');
try {
const response = await fetch('/api/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(settings),
});
if (!response.ok) throw new Error('Failed to save settings');
setStatus('success');
setTimeout(() => setStatus('idle'), 2000);
} catch (error) {
console.error('Error saving settings:', error);
setStatus('error');
setTimeout(() => setStatus('idle'), 2000);
}
};
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const { name, value } = e.target;
setSettings(prev => ({ ...prev, [name]: value }));
};
return (
<>
<TopBar>
<div>
<h1 className="text-lg">Settings</h1>
</div>
<div className="flex-1"></div>
</TopBar>
<MainContent>
<form onSubmit={handleSubmit} className="space-y-6">
<div className="grid grid-cols-1 gap-6 sm:grid-cols-2">
<div>
<div className="space-y-4">
<div>
<label htmlFor="HF_TOKEN" className="block text-sm font-medium mb-2">
Hugging Face Token
<div className="text-gray-500 text-sm ml-1">
Create a Read token on{' '}
<a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer">
{' '}
Huggingface
</a>{' '}
if you need to access gated/private models.
</div>
</label>
<input
type="password"
id="HF_TOKEN"
name="HF_TOKEN"
value={settings.HF_TOKEN}
onChange={handleChange}
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
placeholder="Enter your Hugging Face token"
/>
</div>
<div>
<label htmlFor="TRAINING_FOLDER" className="block text-sm font-medium mb-2">
Training Folder Path
<div className="text-gray-500 text-sm ml-1">
We will store your training information here. Must be an absolute path. If blank, it will default
to the output folder in the project root.
</div>
</label>
<input
type="text"
id="TRAINING_FOLDER"
name="TRAINING_FOLDER"
value={settings.TRAINING_FOLDER}
onChange={handleChange}
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
placeholder="Enter training folder path"
/>
</div>
<div>
<label htmlFor="DATASETS_FOLDER" className="block text-sm font-medium mb-2">
Dataset Folder Path
<div className="text-gray-500 text-sm ml-1">
Where we store and find your datasets.{' '}
<span className="text-orange-800">
Warning: This software may modify datasets so it is recommended you keep a backup somewhere else
or have a dedicated folder for this software.
</span>
</div>
</label>
<input
type="text"
id="DATASETS_FOLDER"
name="DATASETS_FOLDER"
value={settings.DATASETS_FOLDER}
onChange={handleChange}
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
placeholder="Enter datasets folder path"
/>
</div>
</div>
</div>
</div>
<button
type="submit"
disabled={status === 'saving'}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
>
{status === 'saving' ? 'Saving...' : 'Save Settings'}
</button>
{status === 'success' && <p className="text-green-500 text-center">Settings saved successfully!</p>}
{status === 'error' && <p className="text-red-500 text-center">Error saving settings. Please try again.</p>}
</form>
</MainContent>
</>
);
}

View File

@@ -0,0 +1,155 @@
'use client';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
import { FaUpload } from 'react-icons/fa';
import { useCallback, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import axios from 'axios';
export interface AddImagesModalState {
datasetName: string;
onComplete?: () => void;
}
export const addImagesModalState = createGlobalState<AddImagesModalState | null>(null);
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
addImagesModalState.set({ datasetName, onComplete });
}
export default function AddImagesModal() {
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
const [uploadProgress, setUploadProgress] = useState<number>(0);
const [isUploading, setIsUploading] = useState<boolean>(false);
const open = addImagesModalInfo !== null;
const onCancel = () => {
if (!isUploading) {
setAddImagesModalInfo(null);
}
};
const onDone = () => {
if (addImagesModalInfo?.onComplete && !isUploading) {
addImagesModalInfo.onComplete();
setAddImagesModalInfo(null);
}
};
const onDrop = useCallback(async (acceptedFiles: File[]) => {
if (acceptedFiles.length === 0) return;
setIsUploading(true);
setUploadProgress(0);
const formData = new FormData();
acceptedFiles.forEach(file => {
formData.append('files', file);
});
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
try {
await axios.post(`/api/datasets/upload`, formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
onUploadProgress: (progressEvent) => {
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
setUploadProgress(percentCompleted);
},
timeout: 0, // Disable timeout
});
onDone();
} catch (error) {
console.error('Upload failed:', error);
} finally {
setIsUploading(false);
setUploadProgress(0);
}
}, [addImagesModalInfo]);
const { getRootProps, getInputProps, isDragActive } = useDropzone({
onDrop,
accept: {
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'],
'text/*': ['.txt']
},
multiple: true
});
return (
<Dialog open={open} onClose={onCancel} className="relative z-10">
<DialogBackdrop
transition
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
/>
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<DialogPanel
transition
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
>
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
<div className="text-center">
<DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
Add Images to: {addImagesModalInfo?.datasetName}
</DialogTitle>
<div className="w-full">
<div
{...getRootProps()}
className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
>
<input {...getInputProps()} />
<FaUpload className="size-8 mb-3 text-gray-400" />
<p className="text-sm text-gray-200 text-center">
{isDragActive
? 'Drop the files here...'
: 'Drag & drop files here, or click to select files'}
</p>
</div>
{isUploading && (
<div className="mt-4">
<div className="w-full bg-gray-700 rounded-full h-2.5">
<div
className="bg-blue-600 h-2.5 rounded-full"
style={{ width: `${uploadProgress}%` }}
></div>
</div>
<p className="text-sm text-gray-300 mt-2 text-center">
Uploading... {uploadProgress}%
</p>
</div>
)}
</div>
</div>
</div>
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
<button
type="button"
onClick={onDone}
disabled={isUploading}
className={`inline-flex w-full justify-center rounded-md bg-slate-600 px-3 py-2 text-sm font-semibold text-white shadow-xs sm:ml-3 sm:w-auto
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
>
Done
</button>
<button
type="button"
data-autofocus
onClick={onCancel}
disabled={isUploading}
className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
>
Cancel
</button>
</div>
</DialogPanel>
</div>
</div>
</Dialog>
);
}

View File

@@ -0,0 +1,15 @@
interface CardProps {
title?: string;
children?: React.ReactNode;
}
const Card: React.FC<CardProps> = ({ title, children }) => {
return (
<section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
{title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
{children ? children : null}
</section>
);
};
export default Card;

View File

@@ -0,0 +1,171 @@
'use client';
import { useState, useEffect } from 'react';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
import { FaExclamationTriangle, FaInfo } from 'react-icons/fa';
export interface ConfirmState {
title: string;
message?: string;
confirmText?: string;
type?: 'danger' | 'warning' | 'info';
onConfirm?: () => void;
onCancel?: () => void;
}
export const confirmstate = createGlobalState<ConfirmState | null>(null);
export const openConfirm = (confirmProps: ConfirmState) => {
confirmstate.set(confirmProps);
};
export default function ConfirmModal() {
const [confirm, setConfirm] = confirmstate.use();
const [isOpen, setIsOpen] = useState(false);
useEffect(() => {
if (confirm) {
setIsOpen(true);
}
}, [confirm]);
useEffect(() => {
if (!isOpen) {
// use timeout to allow the dialog to close before resetting the state
setTimeout(() => {
setConfirm(null);
}, 500);
}
}, [isOpen]);
const onCancel = () => {
if (confirm?.onCancel) {
confirm.onCancel();
}
setIsOpen(false);
};
const onConfirm = () => {
if (confirm?.onConfirm) {
confirm.onConfirm();
}
setIsOpen(false);
};
let Icon = FaExclamationTriangle;
let color = confirm?.type || 'danger';
// Use conditional rendering for icon
if (color === 'info') {
Icon = FaInfo;
}
// Color mapping for background colors
const getBgColor = () => {
switch (color) {
case 'danger':
return 'bg-red-500';
case 'warning':
return 'bg-yellow-500';
case 'info':
return 'bg-blue-500';
default:
return 'bg-red-500';
}
};
// Color mapping for text colors
const getTextColor = () => {
switch (color) {
case 'danger':
return 'text-red-950';
case 'warning':
return 'text-yellow-950';
case 'info':
return 'text-blue-950';
default:
return 'text-red-950';
}
};
// Color mapping for titles
const getTitleColor = () => {
switch (color) {
case 'danger':
return 'text-red-500';
case 'warning':
return 'text-yellow-500';
case 'info':
return 'text-blue-500';
default:
return 'text-red-500';
}
};
// Button background color mapping
const getButtonBgColor = () => {
switch (color) {
case 'danger':
return 'bg-red-700 hover:bg-red-500';
case 'warning':
return 'bg-yellow-700 hover:bg-yellow-500';
case 'info':
return 'bg-blue-700 hover:bg-blue-500';
default:
return 'bg-red-700 hover:bg-red-500';
}
};
return (
<Dialog open={isOpen} onClose={onCancel} className="relative z-10">
<DialogBackdrop
transition
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
/>
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<DialogPanel
transition
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
>
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
<div className="sm:flex sm:items-start">
<div
className={`mx-auto flex size-12 shrink-0 items-center justify-center rounded-full ${getBgColor()} sm:mx-0 sm:size-10`}
>
<Icon aria-hidden="true" className={`size-6 ${getTextColor()}`} />
</div>
<div className="mt-3 text-center sm:mt-0 sm:ml-4 sm:text-left">
<DialogTitle as="h3" className={`text-base font-semibold ${getTitleColor()}`}>
{confirm?.title}
</DialogTitle>
<div className="mt-2">
<p className="text-sm text-gray-200">{confirm?.message}</p>
</div>
</div>
</div>
</div>
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
<button
type="button"
onClick={onConfirm}
className={`inline-flex w-full justify-center rounded-md ${getButtonBgColor()} px-3 py-2 text-sm font-semibold text-white shadow-xs sm:ml-3 sm:w-auto`}
>
{confirm?.confirmText || 'Confirm'}
</button>
<button
type="button"
data-autofocus
onClick={onCancel}
className="mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0"
>
Cancel
</button>
</div>
</DialogPanel>
</div>
</div>
</Dialog>
);
}

View File

@@ -0,0 +1,186 @@
import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 'react';
import { FaTrashAlt } from 'react-icons/fa';
import { openConfirm } from './ConfirmModal';
import classNames from 'classnames';
interface DatasetImageCardProps {
imageUrl: string;
alt: string;
children?: ReactNode;
className?: string;
onDelete?: () => void;
}
const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
imageUrl,
alt,
children,
className = '',
onDelete = () => {},
}) => {
const cardRef = useRef<HTMLDivElement>(null);
const [isVisible, setIsVisible] = useState<boolean>(false);
const [loaded, setLoaded] = useState<boolean>(false);
const [isCaptionLoaded, setIsCaptionLoaded] = useState<boolean>(false);
const [caption, setCaption] = useState<string>('');
const [savedCaption, setSavedCaption] = useState<string>('');
const isGettingCaption = useRef<boolean>(false);
const fetchCaption = async () => {
try {
if (isGettingCaption.current || isCaptionLoaded) return;
isGettingCaption.current = true;
const response = await fetch(`/api/caption/${encodeURIComponent(imageUrl)}`);
const data = await response.text();
setCaption(data);
setSavedCaption(data);
setIsCaptionLoaded(true);
} catch (error) {
console.error('Error fetching caption:', error);
}
};
const saveCaption = () => {
const trimmedCaption = caption.trim();
if (trimmedCaption === savedCaption) return;
fetch('/api/img/caption', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ imgPath: imageUrl, caption: trimmedCaption }),
})
.then(res => res.json())
.then(data => {
console.log('Caption saved:', data);
setSavedCaption(trimmedCaption);
})
.catch(error => {
console.error('Error saving caption:', error);
});
};
useEffect(() => {
isVisible && fetchCaption();
}, [isVisible]);
useEffect(() => {
// Create intersection observer to check visibility
const observer = new IntersectionObserver(
entries => {
if (entries[0].isIntersecting) {
setIsVisible(true);
observer.disconnect();
}
},
{ threshold: 0.1 },
);
if (cardRef.current) {
observer.observe(cardRef.current);
}
return () => {
observer.disconnect();
};
}, []);
const handleLoad = (): void => {
setLoaded(true);
};
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>): void => {
// If Enter is pressed without Shift, prevent default behavior and save
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
saveCaption();
}
};
const isCaptionCurrent = caption.trim() === savedCaption;
return (
<div className={`flex flex-col ${className}`}>
{/* Square image container */}
<div
ref={cardRef}
className="relative w-full"
style={{ paddingBottom: '100%' }} // Make it square
>
<div className="absolute inset-0 rounded-t-lg shadow-md">
{isVisible && (
<img
src={`/api/img/${encodeURIComponent(imageUrl)}`}
alt={alt}
onLoad={handleLoad}
className={`w-full h-full object-contain transition-opacity duration-300 ${
loaded ? 'opacity-100' : 'opacity-0'
}`}
/>
)}
{children && <div className="absolute inset-0 flex items-center justify-center">{children}</div>}
<div className="absolute top-1 right-1">
<button
className="bg-gray-800 rounded-full p-2"
onClick={() => {
openConfirm({
title: 'Delete Image',
message: 'Are you sure you want to delete this image? This action cannot be undone.',
type: 'warning',
confirmText: 'Delete',
onConfirm: () => {
fetch('/api/img/delete', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ imgPath: imageUrl }),
})
.then(res => res.json())
.then(data => {
console.log('Image deleted:', data);
onDelete();
})
.catch(error => {
console.error('Error deleting image:', error);
});
},
});
}}
>
<FaTrashAlt />
</button>
</div>
</div>
</div>
{/* Text area below the image */}
<div
className={classNames('w-full p-2 bg-gray-800 text-white text-sm rounded-b-lg h-[75px]', {
'border-blue-500 border-2': !isCaptionCurrent,
'border-transparent border-2': isCaptionCurrent,
})}
>
{isVisible && isCaptionLoaded && (
<form
onSubmit={e => {
e.preventDefault();
saveCaption();
}}
onBlur={saveCaption}
>
<textarea
className="w-full bg-transparent resize-none outline-none focus:ring-0 focus:outline-none"
value={caption}
rows={3}
onChange={e => setCaption(e.target.value)}
onKeyDown={handleKeyDown}
/>
</form>
)}
</div>
</div>
);
};
export default DatasetImageCard;

View File

@@ -0,0 +1,90 @@
import React from 'react';
import useFilesList from '@/hooks/useFilesList';
import Link from 'next/link';
import { Loader2, AlertCircle, Download, Box, Brain } from 'lucide-react';
export default function FilesWidget({ jobID }: { jobID: string }) {
const { files, status, refreshFiles } = useFilesList(jobID, 5000);
const cleanSize = (size: number) => {
if (size < 1024) {
return `${size} B`;
} else if (size < 1024 * 1024) {
return `${(size / 1024).toFixed(1)} KB`;
} else if (size < 1024 * 1024 * 1024) {
return `${(size / (1024 * 1024)).toFixed(1)} MB`;
} else {
return `${(size / (1024 * 1024 * 1024)).toFixed(1)} GB`;
}
};
return (
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden hover:shadow-2xl transition-all duration-300 border border-gray-800">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<div className="flex items-center space-x-2">
<Brain className="w-5 h-5 text-purple-400" />
<h2 className="font-semibold text-gray-100">Checkpoints</h2>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
{files.length}
</span>
</div>
</div>
<div className="p-2">
{status === 'loading' && (
<div className="flex items-center justify-center py-4">
<Loader2 className="w-5 h-5 text-gray-400 animate-spin" />
</div>
)}
{status === 'error' && (
<div className="flex items-center justify-center py-4 text-rose-400 space-x-2">
<AlertCircle className="w-4 h-4" />
<span className="text-sm">Error loading checkpoints</span>
</div>
)}
{['success', 'refreshing'].includes(status) && (
<div className="space-y-1">
{files.map((file, index) => {
const fileName = file.path.split('/').pop() || '';
const nameWithoutExt = fileName.replace('.safetensors', '');
return (
<a
key={index}
target='_blank'
href={`/api/files/${encodeURIComponent(file.path)}`}
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
>
<div className="flex items-center space-x-2 min-w-0">
<Box className="w-4 h-4 text-purple-400 flex-shrink-0" />
<div className="flex flex-col min-w-0">
<div className="flex text-sm text-gray-200">
<span className="overflow-hidden text-ellipsis direction-rtl whitespace-nowrap">
{nameWithoutExt}
</span>
</div>
<span className="text-xs text-gray-500">.safetensors</span>
</div>
</div>
<div className="flex items-center space-x-3 flex-shrink-0">
<span className="text-xs text-gray-400">{cleanSize(file.size)}</span>
<div className="bg-purple-500 bg-opacity-0 group-hover:bg-opacity-10 rounded-full p-1 transition-all">
<Download className="w-3 h-3 text-purple-400" />
</div>
</div>
</a>
);
})}
</div>
)}
{['success', 'refreshing'].includes(status) && files.length === 0 && (
<div className="text-center py-4 text-gray-400 text-sm">
No checkpoints available
</div>
)}
</div>
</div>
);
}

View File

@@ -0,0 +1,123 @@
import React, { useState, useEffect } from 'react';
import { GPUApiResponse } from '@/types';
import Loading from '@/components/Loading';
import GPUWidget from '@/components/GPUWidget';
const GpuMonitor: React.FC = () => {
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
const [loading, setLoading] = useState<boolean>(true);
const [error, setError] = useState<string | null>(null);
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
useEffect(() => {
const fetchGpuInfo = async () => {
try {
const response = await fetch('/api/gpu');
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data: GPUApiResponse = await response.json();
setGpuData(data);
setLastUpdated(new Date());
setError(null);
} catch (err) {
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
} finally {
setLoading(false);
}
};
// Fetch immediately on component mount
fetchGpuInfo();
// Set up interval to fetch every 1 seconds
const intervalId = setInterval(fetchGpuInfo, 1000);
// Clean up interval on component unmount
return () => clearInterval(intervalId);
}, []);
const getGridClasses = (gpuCount: number): string => {
switch (gpuCount) {
case 1:
return 'grid-cols-1';
case 2:
return 'grid-cols-2';
case 3:
return 'grid-cols-3';
case 4:
return 'grid-cols-4';
case 5:
case 6:
return 'grid-cols-3';
case 7:
case 8:
return 'grid-cols-4';
case 9:
case 10:
return 'grid-cols-5';
default:
return 'grid-cols-3';
}
};
if (loading) {
return <Loading />;
}
if (error) {
return (
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">Error!</strong>
<span className="block sm:inline"> {error}</span>
</div>
);
}
if (!gpuData) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPU data available.</span>
</div>
);
}
if (!gpuData.hasNvidiaSmi) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
</div>
);
}
if (gpuData.gpus.length === 0) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
</div>
);
}
const gridClass = getGridClasses(gpuData.gpus.length);
return (
<div className="w-full">
<div className="flex justify-between items-center mb-2">
<h1 className="text-md">GPU Monitor</h1>
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
</div>
<div className={`grid ${gridClass} gap-3`}>
{gpuData.gpus.map((gpu, idx) => (
<GPUWidget key={idx} gpu={gpu} />
))}
</div>
</div>
);
};
export default GpuMonitor;

View File

@@ -0,0 +1,110 @@
import React from 'react';
import { GpuInfo } from '@/types';
import { ChevronRight, Thermometer, Zap, Clock, HardDrive, Fan, Cpu } from 'lucide-react';
interface GPUWidgetProps {
gpu: GpuInfo;
}
export default function GPUWidget({ gpu }: GPUWidgetProps) {
const formatMemory = (mb: number): string => {
return mb >= 1024 ? `${(mb / 1024).toFixed(1)} GB` : `${mb} MB`;
};
const getUtilizationColor = (value: number): string => {
return value < 30 ? 'bg-emerald-500' : value < 70 ? 'bg-amber-500' : 'bg-rose-500';
};
const getTemperatureColor = (temp: number): string => {
return temp < 50 ? 'text-emerald-500' : temp < 80 ? 'text-amber-500' : 'text-rose-500';
};
return (
<div className="bg-gray-900 rounded-xl shadow-lg overflow-hidden hover:shadow-2xl transition-all duration-300 border border-gray-800">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<div className="flex items-center space-x-2">
<h2 className="font-semibold text-gray-100">{gpu.name}</h2>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
#{gpu.index}
</span>
</div>
</div>
<div className="p-4 space-y-4">
{/* Temperature, Fan, and Utilization Section */}
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<div className="flex items-center space-x-2">
<Thermometer className={`w-4 h-4 ${getTemperatureColor(gpu.temperature)}`} />
<div>
<p className="text-xs text-gray-400">Temperature</p>
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>
{gpu.temperature}°C
</p>
</div>
</div>
<div className="flex items-center space-x-2">
<Fan className="w-4 h-4 text-blue-400" />
<div>
<p className="text-xs text-gray-400">Fan Speed</p>
<p className="text-sm font-medium text-blue-400">
{gpu.fan.speed}%
</p>
</div>
</div>
</div>
<div>
<div className="flex items-center space-x-2 mb-1">
<Cpu className="w-4 h-4 text-gray-400" />
<p className="text-xs text-gray-400">GPU Load</p>
<span className="text-xs text-gray-300 ml-auto">{gpu.utilization.gpu}%</span>
</div>
<div className="w-full bg-gray-700 rounded-full h-1">
<div
className={`h-1 rounded-full transition-all ${getUtilizationColor(gpu.utilization.gpu)}`}
style={{ width: `${gpu.utilization.gpu}%` }}
/>
</div>
<div className="flex items-center space-x-2 mb-1 mt-3">
<HardDrive className="w-4 h-4 text-blue-400" />
<p className="text-xs text-gray-400">Memory</p>
<span className="text-xs text-gray-300 ml-auto">
{((gpu.memory.used / gpu.memory.total) * 100).toFixed(1)}%
</span>
</div>
<div className="w-full bg-gray-700 rounded-full h-1">
<div
className="h-1 rounded-full bg-blue-500 transition-all"
style={{ width: `${(gpu.memory.used / gpu.memory.total) * 100}%` }}
/>
</div>
<p className="text-xs text-gray-400 mt-0.5">
{formatMemory(gpu.memory.used)} / {formatMemory(gpu.memory.total)}
</p>
</div>
</div>
{/* Power and Clocks Section */}
<div className="grid grid-cols-2 gap-4 pt-2 border-t border-gray-800">
<div className="flex items-start space-x-2">
<Clock className="w-4 h-4 text-purple-400" />
<div>
<p className="text-xs text-gray-400">Clock Speed</p>
<p className="text-sm text-gray-200">{gpu.clocks.graphics} MHz</p>
</div>
</div>
<div className="flex items-start space-x-2">
<Zap className="w-4 h-4 text-amber-400" />
<div>
<p className="text-xs text-gray-400">Power Draw</p>
<p className="text-sm text-gray-200">
{gpu.power.draw.toFixed(1)}W
<span className="text-gray-400 text-xs"> / {gpu.power.limit.toFixed(1)}W</span>
</p>
</div>
</div>
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,87 @@
import Link from 'next/link';
import { Eye, Trash2, Pen, Play, Pause } from 'lucide-react';
import { Button } from '@headlessui/react';
import { openConfirm } from '@/components/ConfirmModal';
import { Job } from '@prisma/client';
import { startJob, stopJob, deleteJob, getAvaliableJobActions } from '@/utils/jobs';
interface JobActionBarProps {
job: Job;
onRefresh?: () => void;
afterDelete?: () => void;
hideView?: boolean;
className?: string;
}
export default function JobActionBar({ job, onRefresh, afterDelete, className, hideView }: JobActionBarProps) {
const { canStart, canStop, canDelete, canEdit } = getAvaliableJobActions(job);
if (!afterDelete) afterDelete = onRefresh;
return (
<div className={`${className}`}>
{canStart && (
<Button
onClick={async () => {
if (!canStart) return;
await startJob(job.id);
if (onRefresh) onRefresh();
}}
className={`ml-2 opacity-100`}
>
<Play />
</Button>
)}
{canStop && (
<Button
onClick={() => {
if (!canStop) return;
openConfirm({
title: 'Stop Job',
message: `Are you sure you want to stop the job "${job.name}"? You CAN resume later.`,
type: 'info',
confirmText: 'Stop',
onConfirm: async () => {
await stopJob(job.id);
if (onRefresh) onRefresh();
},
});
}}
className={`ml-2 opacity-100`}
>
<Pause />
</Button>
)}
{!hideView && (
<Link href={`/jobs/${job.id}`} className="ml-2 text-gray-200 hover:text-gray-100 inline-block">
<Eye />
</Link>
)}
{canEdit && (
<Link href={`/jobs/new?id=${job.id}`} className="ml-2 hover:text-gray-100 inline-block">
<Pen />
</Link>
)}
{canDelete && (
<Button
onClick={() => {
if (!canDelete) return;
openConfirm({
title: 'Delete Job',
message: `Are you sure you want to delete the job "${job.name}"? This will also permanently remove it from your disk.`,
type: 'warning',
confirmText: 'Delete',
onConfirm: async () => {
await deleteJob(job.id);
if (afterDelete) afterDelete();
},
});
}}
className={`ml-2 opacity-100`}
>
<Trash2 />
</Button>
)}
</div>
);
}

View File

@@ -0,0 +1,104 @@
import { Job } from '@prisma/client';
import useGPUInfo from '@/hooks/useGPUInfo';
import GPUWidget from '@/components/GPUWidget';
import FilesWidget from '@/components/FilesWidget';
import { getTotalSteps } from '@/utils/jobs';
import { Cpu, HardDrive, Info, Gauge } from 'lucide-react';
import { useMemo } from 'react';
interface JobOverviewProps {
job: Job;
}
export default function JobOverview({ job }: JobOverviewProps) {
const gpuIds = useMemo(() => job.gpu_ids.split(',').map(id => parseInt(id)), [job.gpu_ids]);
const { gpuList, isGPUInfoLoaded } = useGPUInfo(gpuIds, 5000);
const totalSteps = getTotalSteps(job);
const progress = (job.step / totalSteps) * 100;
const isStopping = job.stop && job.status === 'running';
const getStatusColor = (status: string) => {
switch (status.toLowerCase()) {
case 'running':
return 'bg-emerald-500/10 text-emerald-500';
case 'stopping':
return 'bg-amber-500/10 text-amber-500';
case 'stopped':
return 'bg-gray-500/10 text-gray-400';
case 'completed':
return 'bg-blue-500/10 text-blue-500';
case 'error':
return 'bg-rose-500/10 text-rose-500';
default:
return 'bg-gray-500/10 text-gray-400';
}
};
let status = job.status;
if (isStopping) {
status = 'stopping';
}
return (
<div className="grid grid-cols-1 gap-6 md:grid-cols-3">
{/* Job Information Panel */}
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
</div>
<div className="p-4 space-y-6">
{/* Progress Bar */}
<div className="space-y-2">
<div className="flex items-center justify-between text-sm">
<span className="text-gray-400">Progress</span>
<span className="text-gray-200">
Step {job.step} of {totalSteps}
</span>
</div>
<div className="w-full bg-gray-800 rounded-full h-2">
<div className="h-2 rounded-full bg-blue-500 transition-all" style={{ width: `${progress}%` }} />
</div>
</div>
{/* Job Info Grid */}
<div className="grid gap-4 grid-cols-1 md:grid-cols-3">
<div className="flex items-center space-x-4">
<HardDrive className="w-5 h-5 text-blue-400" />
<div>
<p className="text-xs text-gray-400">Job Name</p>
<p className="text-sm font-medium text-gray-200">{job.name}</p>
</div>
</div>
<div className="flex items-center space-x-4">
<Cpu className="w-5 h-5 text-purple-400" />
<div>
<p className="text-xs text-gray-400">Assigned GPUs</p>
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
</div>
</div>
<div className="flex items-center space-x-4">
<Gauge className="w-5 h-5 text-green-400" />
<div>
<p className="text-xs text-gray-400">Speed</p>
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
</div>
</div>
</div>
</div>
</div>
{/* GPU Widget Panel */}
<div className="col-span-1">
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
<div className="mt-4">
<FilesWidget jobID={job.id} />
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,79 @@
import useJobsList from '@/hooks/useJobsList';
import Link from 'next/link';
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
import { JobConfig } from '@/types';
import JobActionBar from './JobActionBar';
interface JobsTableProps {
onlyActive?: boolean;
}
export default function JobsTable({ onlyActive = false }: JobsTableProps) {
const { jobs, status, refreshJobs } = useJobsList(onlyActive);
const isLoading = status === 'loading';
const columns: TableColumn[] = [
{
title: 'Name',
key: 'name',
render: row => (
<Link href={`/jobs/${row.id}`} className="font-medium whitespace-nowrap">
{row.name}
</Link>
),
},
{
title: 'Steps',
key: 'steps',
render: row => {
const jobConfig: JobConfig = JSON.parse(row.job_config);
const totalSteps = jobConfig.config.process[0].train.steps;
return (
<div className="flex items-center">
<span>
{row.step} / {totalSteps}
</span>
<div className="w-16 bg-gray-700 rounded-full h-1.5 ml-2">
<div
className="bg-blue-500 h-1.5 rounded-full"
style={{ width: `${(row.step / totalSteps) * 100}%` }}
></div>
</div>
</div>
);
},
},
{
title: 'GPU',
key: 'gpu_ids',
},
{
title: 'Status',
key: 'status',
render: row => {
let statusClass = 'text-gray-400';
if (row.status === 'completed') statusClass = 'text-green-400';
if (row.status === 'failed') statusClass = 'text-red-400';
if (row.status === 'running') statusClass = 'text-blue-400';
return <span className={statusClass}>{row.status}</span>;
},
},
{
title: 'Info',
key: 'info',
className: 'truncate max-w-xs',
},
{
title: 'Actions',
key: 'actions',
className: 'text-right',
render: row => {
return <JobActionBar job={row} onRefresh={refreshJobs} />;
},
},
];
return <UniversalTable columns={columns} rows={jobs} isLoading={isLoading} onRefresh={refreshJobs} />;
}

View File

@@ -0,0 +1,7 @@
export default function Loading() {
return (
<div className="flex justify-center items-center h-64">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
</div>
);
}

110
ui/src/components/Modal.tsx Normal file
View File

@@ -0,0 +1,110 @@
import React, { Fragment, useEffect } from 'react';
interface ModalProps {
isOpen: boolean;
onClose: () => void;
title?: string;
children: React.ReactNode;
showCloseButton?: boolean;
size?: 'sm' | 'md' | 'lg' | 'xl';
closeOnOverlayClick?: boolean;
}
export const Modal: React.FC<ModalProps> = ({
isOpen,
onClose,
title,
children,
showCloseButton = true,
size = 'md',
closeOnOverlayClick = true,
}) => {
// Close on ESC key press
useEffect(() => {
const handleEscKey = (e: KeyboardEvent) => {
if (e.key === 'Escape' && isOpen) {
onClose();
}
};
if (isOpen) {
document.addEventListener('keydown', handleEscKey);
// Prevent body scrolling when modal is open
document.body.style.overflow = 'hidden';
}
return () => {
document.removeEventListener('keydown', handleEscKey);
document.body.style.overflow = 'auto';
};
}, [isOpen, onClose]);
// Handle overlay click
const handleOverlayClick = (e: React.MouseEvent<HTMLDivElement>) => {
if (e.target === e.currentTarget && closeOnOverlayClick) {
onClose();
}
};
if (!isOpen) return null;
// Size mapping
const sizeClasses = {
sm: 'max-w-md',
md: 'max-w-lg',
lg: 'max-w-2xl',
xl: 'max-w-4xl',
};
return (
<Fragment>
{/* Modal backdrop */}
<div
className="fixed inset-0 z-50 flex items-center justify-center overflow-y-auto bg-gray-900 bg-opacity-75 backdrop-blur-sm transition-opacity"
onClick={handleOverlayClick}
aria-modal="true"
role="dialog"
aria-labelledby="modal-title"
>
{/* Modal panel */}
<div
className={`relative mx-auto w-full ${sizeClasses[size]} rounded-lg bg-gray-800 border border-gray-700 shadow-xl transition-all`}
onClick={e => e.stopPropagation()}
>
{/* Modal header */}
{(title || showCloseButton) && (
<div className="flex items-center justify-between rounded-t-lg border-b border-gray-700 bg-gray-850 px-6 py-4">
{title && (
<h3 id="modal-title" className="text-xl font-semibold text-gray-100">
{title}
</h3>
)}
{showCloseButton && (
<button
type="button"
className="ml-auto inline-flex items-center justify-center rounded-md p-2 text-gray-400 hover:bg-gray-700 hover:text-gray-200 focus:outline-none focus:ring-2 focus:ring-blue-500"
onClick={onClose}
aria-label="Close modal"
>
<svg
className="h-5 w-5"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
)}
</div>
)}
{/* Modal content */}
<div className="px-6 py-4">{children}</div>
</div>
</div>
</Fragment>
);
};

View File

@@ -0,0 +1,71 @@
import React, { useRef, useEffect, useState, ReactNode } from 'react';
import { sampleImageModalState } from '@/components/SampleImageModal';
interface SampleImageCardProps {
imageUrl: string;
alt: string;
numSamples: number;
sampleImages: string[];
children?: ReactNode;
className?: string;
onDelete?: () => void;
}
const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, numSamples, sampleImages, children, className = '' }) => {
const cardRef = useRef<HTMLDivElement>(null);
const [isVisible, setIsVisible] = useState<boolean>(false);
const [loaded, setLoaded] = useState<boolean>(false);
useEffect(() => {
// Create intersection observer to check visibility
const observer = new IntersectionObserver(
entries => {
if (entries[0].isIntersecting) {
setIsVisible(true);
observer.disconnect();
}
},
{ threshold: 0.1 },
);
if (cardRef.current) {
observer.observe(cardRef.current);
}
return () => {
observer.disconnect();
};
}, []);
const handleLoad = (): void => {
setLoaded(true);
};
return (
<div className={`flex flex-col ${className}`}>
{/* Square image container */}
<div
ref={cardRef}
className="relative w-full cursor-pointer"
style={{ paddingBottom: '100%' }} // Make it square
onClick={() => sampleImageModalState.set({ imgPath: imageUrl, numSamples, sampleImages })}
>
<div className="absolute inset-0 rounded-t-lg shadow-md">
{isVisible && (
<img
src={`/api/img/${encodeURIComponent(imageUrl)}`}
alt={alt}
onLoad={handleLoad}
className={`w-full h-full object-contain transition-opacity duration-300 ${
loaded ? 'opacity-100' : 'opacity-0'
}`}
/>
)}
{children && <div className="absolute inset-0 flex items-center justify-center">{children}</div>}
</div>
</div>
</div>
);
};
export default SampleImageCard;

View File

@@ -0,0 +1,190 @@
'use client';
import { useState, useEffect, useMemo } from 'react';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel } from '@headlessui/react';
export interface SampleImageModalState {
imgPath: string;
numSamples: number;
sampleImages: string[];
}
export const sampleImageModalState = createGlobalState<SampleImageModalState | null>(null);
export const openSampleImage = (sampleImageProps: SampleImageModalState) => {
sampleImageModalState.set(sampleImageProps);
};
export default function SampleImageModal() {
const [imageModal, setImageModal] = sampleImageModalState.use();
const [isOpen, setIsOpen] = useState(false);
useEffect(() => {
if (imageModal) {
setIsOpen(true);
}
}, [imageModal]);
useEffect(() => {
if (!isOpen) {
// use timeout to allow the dialog to close before resetting the state
setTimeout(() => {
setImageModal(null);
}, 500);
}
}, [isOpen]);
const onCancel = () => {
setIsOpen(false);
};
const imgInfo = useMemo(() => {
const ii = {
filename: '',
step: 0,
promptIdx: 0,
};
if (imageModal?.imgPath) {
const filename = imageModal.imgPath.split('/').pop();
if (!filename) return ii;
// filename is <timestep>__<zero_pad_step>_<prompt_idx>.<ext>
ii.filename = filename as string;
const parts = filename
.split('.')[0]
.split('_')
.filter(p => p !== '');
if (parts.length === 3) {
ii.step = parseInt(parts[1]);
ii.promptIdx = parseInt(parts[2]);
}
}
return ii;
}, [imageModal]);
const handleArrowUp = () => {
if (!imageModal) return;
console.log('Arrow Up pressed');
// Change image to same sample but up one step
const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
if (currentIdx === -1) return;
const nextIdx = currentIdx - imageModal.numSamples;
if (nextIdx < 0) return;
openSampleImage({
imgPath: imageModal.sampleImages[nextIdx],
numSamples: imageModal.numSamples,
sampleImages: imageModal.sampleImages,
});
};
const handleArrowDown = () => {
if (!imageModal) return;
console.log('Arrow Down pressed');
// Change image to same sample but down one step
const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
if (currentIdx === -1) return;
const nextIdx = currentIdx + imageModal.numSamples;
if (nextIdx >= imageModal.sampleImages.length) return;
openSampleImage({
imgPath: imageModal.sampleImages[nextIdx],
numSamples: imageModal.numSamples,
sampleImages: imageModal.sampleImages,
});
};
const handleArrowLeft = () => {
if (!imageModal) return;
if (imgInfo.promptIdx === 0) return;
console.log('Arrow Left pressed');
// go to previous sample
const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
if (currentIdx === -1) return;
const minIdx = currentIdx - imgInfo.promptIdx;
const nextIdx = currentIdx - 1;
if (nextIdx < minIdx) return;
openSampleImage({
imgPath: imageModal.sampleImages[nextIdx],
numSamples: imageModal.numSamples,
sampleImages: imageModal.sampleImages,
});
};
const handleArrowRight = () => {
if (!imageModal) return;
console.log('Arrow Right pressed');
// go to next sample
const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
if (currentIdx === -1) return;
const stepMinIdx = currentIdx - imgInfo.promptIdx;
const maxIdx = stepMinIdx + imageModal.numSamples - 1;
const nextIdx = currentIdx + 1;
if (nextIdx > maxIdx) return;
if (nextIdx >= imageModal.sampleImages.length) return;
openSampleImage({
imgPath: imageModal.sampleImages[nextIdx],
numSamples: imageModal.numSamples,
sampleImages: imageModal.sampleImages,
});
};
// Handle keyboard events
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (!isOpen) return;
switch (event.key) {
case 'Escape':
onCancel();
break;
case 'ArrowUp':
handleArrowUp();
break;
case 'ArrowDown':
handleArrowDown();
break;
case 'ArrowLeft':
handleArrowLeft();
break;
case 'ArrowRight':
handleArrowRight();
break;
default:
break;
}
};
window.addEventListener('keydown', handleKeyDown);
return () => {
window.removeEventListener('keydown', handleKeyDown);
};
}, [isOpen, imageModal, imgInfo]);
return (
<Dialog open={isOpen} onClose={onCancel} className="relative z-10">
<DialogBackdrop
transition
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
/>
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
<div className="flex min-h-full items-center justify-center p-4 text-center">
<DialogPanel
transition
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in max-w-[95%] max-h-[95vh] data-closed:sm:translate-y-0 data-closed:sm:scale-95"
>
<div className="flex justify-center items-center">
{imageModal?.imgPath && (
<img
src={`/api/img/${encodeURIComponent(imageModal.imgPath)}`}
alt="Sample Image"
className="max-w-full max-h-[calc(95vh-2rem)] object-contain"
/>
)}
</div>
<div className="bg-gray-950 text-center text-sm p-2">step: {imgInfo.step}</div>
</DialogPanel>
</div>
</div>
</Dialog>
);
}

View File

@@ -0,0 +1,95 @@
import { useMemo } from 'react';
import useSampleImages from '@/hooks/useSampleImages';
import SampleImageCard from './SampleImageCard';
import { Job } from '@prisma/client';
import { JobConfig } from '@/types';
interface SampleImagesProps {
job: Job;
}
export default function SampleImages({ job }: SampleImagesProps) {
const { sampleImages, status, refreshSampleImages } = useSampleImages(job.id, 5000);
const numSamples = useMemo(() => {
if (job?.job_config) {
const jobConfig = JSON.parse(job.job_config) as JobConfig;
const sampleConfig = jobConfig.config.process[0].sample;
return sampleConfig.prompts.length;
}
return 10;
}, [job]);
// Use direct Tailwind class without string interpolation
// This way Tailwind can properly generate the class
// I hate this, but it's the only way to make it work
const gridColsClass = useMemo(() => {
const cols = Math.min(numSamples, 20);
switch (cols) {
case 1:
return 'grid-cols-1';
case 2:
return 'grid-cols-2';
case 3:
return 'grid-cols-3';
case 4:
return 'grid-cols-4';
case 5:
return 'grid-cols-5';
case 6:
return 'grid-cols-6';
case 7:
return 'grid-cols-7';
case 8:
return 'grid-cols-8';
case 9:
return 'grid-cols-9';
case 10:
return 'grid-cols-10';
case 11:
return 'grid-cols-11';
case 12:
return 'grid-cols-12';
case 13:
return 'grid-cols-13';
case 14:
return 'grid-cols-14';
case 15:
return 'grid-cols-15';
case 16:
return 'grid-cols-16';
case 17:
return 'grid-cols-17';
case 18:
return 'grid-cols-18';
case 19:
return 'grid-cols-19';
case 20:
return 'grid-cols-20';
default:
return 'grid-cols-1';
}
}, [numSamples]);
return (
<div>
<div className="pb-4">
{status === 'loading' && sampleImages.length === 0 && <p>Loading...</p>}
{status === 'error' && <p>Error fetching sample images</p>}
{sampleImages && (
<div className={`grid ${gridColsClass} gap-1`}>
{sampleImages.map((sample: string) => (
<SampleImageCard
key={sample}
imageUrl={sample}
numSamples={numSamples}
sampleImages={sampleImages}
alt="Sample Image"
/>
))}
</div>
)}
</div>
</div>
);
}

View File

@@ -0,0 +1,65 @@
import Link from 'next/link';
import { Home, Settings, BrainCircuit, Images } from 'lucide-react';
const Sidebar = () => {
const navigation = [
{ name: 'Dashboard', href: '/dashboard', icon: Home },
{ name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
{ name: 'Datasets', href: '/datasets', icon: Images },
{ name: 'Settings', href: '/settings', icon: Settings },
];
return (
<div className="flex flex-col w-64 bg-gray-900 text-gray-100">
<div className="p-4">
<h1 className="text-xl">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-8 mr-3 inline" />
Ostris - AI Toolkit
</h1>
</div>
<nav className="flex-1">
<ul className="px-2 py-4 space-y-2">
{navigation.map(item => (
<li key={item.name}>
<Link
href={item.href}
className="flex items-center px-4 py-2 text-gray-300 hover:bg-gray-800 rounded-lg transition-colors"
>
<item.icon className="w-5 h-5 mr-3" />
{item.name}
</Link>
</li>
))}
</ul>
</nav>
<a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
<div className='min-w-[26px] min-h-[26px]'>
<svg
viewBox="0 0 512 512"
xmlns="http://www.w3.org/2000/svg"
fillRule="evenodd"
clipRule="evenodd"
strokeLinejoin="round"
strokeMiterlimit="2"
>
<g transform="matrix(.47407 0 0 .47407 .383 .422)">
<clipPath id="prefix__a">
<path d="M0 0h1080v1080H0z"></path>
</clipPath>
<g clipPath="url(#prefix__a)">
<path
d="M1033.05 324.45c-.19-137.9-107.59-250.92-233.6-291.7-156.48-50.64-362.86-43.3-512.28 27.2-181.1 85.46-237.99 272.66-240.11 459.36-1.74 153.5 13.58 557.79 241.62 560.67 169.44 2.15 194.67-216.18 273.07-321.33 55.78-74.81 127.6-95.94 216.01-117.82 151.95-37.61 255.51-157.53 255.29-316.38z"
fillRule="nonzero"
fill="#ffffff"
></path>
</g>
</g>
</svg>
</div>
<div className="text-gray-500 text-md mb-2 flex-1 pt-2 pl-2">Support me on Patreon</div>
</a>
</div>
);
};
export default Sidebar;

View File

@@ -0,0 +1,11 @@
'use client';
import { createContext, useContext, useEffect, useState } from 'react';
const ThemeContext = createContext({ isDark: true });
export const ThemeProvider = ({ children }: { children: React.ReactNode }) => {
const [isDark, setIsDark] = useState(true);
return <ThemeContext.Provider value={{ isDark }}>{children}</ThemeContext.Provider>;
};

View File

@@ -0,0 +1,72 @@
import Loading from './Loading';
import classNames from 'classnames';
export interface TableColumn {
title: string;
key: string;
render?: (row: any) => React.ReactNode;
className?: string;
}
interface TableRow {
[key: string]: any;
}
interface TableProps {
columns: TableColumn[];
rows: TableRow[];
isLoading: boolean;
onRefresh: () => void;
}
export default function UniversalTable({ columns, rows, isLoading, onRefresh = () => {} }: TableProps) {
return (
<div className="w-full bg-gray-900 rounded-md shadow-md">
{isLoading ? (
<div className="p-4 flex justify-center">
<Loading />
</div>
) : rows.length === 0 ? (
<div className="p-6 text-center text-gray-400">
<p className="text-sm">Empty</p>
<button
onClick={() => onRefresh()}
className="mt-2 px-3 py-1 text-xs bg-gray-800 hover:bg-gray-700 text-gray-300 rounded transition-colors"
>
Refresh
</button>
</div>
) : (
<div className="overflow-x-auto">
<table className="w-full text-sm text-left text-gray-300">
<thead className="text-xs uppercase bg-gray-800 text-gray-400">
<tr>
{columns.map(column => (
<th key={column.key} className="px-3 py-2">
{column.title}
</th>
))}
</tr>
</thead>
<tbody>
{rows?.map((row, index) => {
// Style for alternating rows
const rowClass = index % 2 === 0 ? 'bg-gray-900' : 'bg-gray-800';
return (
<tr key={index} className={`${rowClass} border-b border-gray-700 hover:bg-gray-700`}>
{columns.map(column => (
<td key={column.key} className={classNames('px-3 py-2', column.className)}>
{column.render ? column.render(row) : row[column.key]}
</td>
))}
</tr>
);
})}
</tbody>
</table>
</div>
)}
</div>
);
}

View File

@@ -0,0 +1,194 @@
import React from 'react';
import classNames from 'classnames';
const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
const inputClasses =
'w-full text-sm px-3 py-1 bg-gray-800 border border-gray-700 rounded-sm focus:ring-2 focus:ring-gray-600 focus:border-transparent';
export interface InputProps {
label?: string;
className?: string;
placeholder?: string;
required?: boolean;
}
export interface TextInputProps extends InputProps {
value: string;
onChange: (value: string) => void;
type?: 'text' | 'password';
disabled?: boolean;
}
export const TextInput = (props: TextInputProps) => {
const { label, value, onChange, placeholder, required, disabled } = props;
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<input
type={props.type || 'text'}
value={value}
onChange={e => {
if (disabled) return;
onChange(e.target.value);
}}
className={`${inputClasses} ${disabled && 'opacity-30 cursor-not-allowed'}`}
placeholder={placeholder}
required={required}
disabled={disabled}
/>
</div>
);
};
export interface NumberInputProps extends InputProps {
value: number;
onChange: (value: number) => void;
min?: number;
max?: number;
}
export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max } = props;
// Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
// Sync internal state with prop value
React.useEffect(() => {
setInputValue(value ?? '');
}, [value]);
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<input
type="number"
value={inputValue}
onChange={e => {
const rawValue = e.target.value;
// Update the input display with the raw value
setInputValue(rawValue);
// Handle empty or partial inputs
if (rawValue === '' || rawValue === '-') {
// For empty or partial negative input, don't call onChange yet
return;
}
const numValue = Number(rawValue);
// Only apply constraints and call onChange when we have a valid number
if (!isNaN(numValue)) {
let constrainedValue = numValue;
// Apply min/max constraints if they exist
if (min !== undefined && constrainedValue < min) {
constrainedValue = min;
}
if (max !== undefined && constrainedValue > max) {
constrainedValue = max;
}
onChange(constrainedValue);
}
}}
className={inputClasses}
placeholder={placeholder}
required={required}
min={min}
max={max}
step="any"
/>
</div>
);
};
export interface SelectInputProps extends InputProps {
value: string;
onChange: (value: string) => void;
options: { value: string; label: string }[];
}
export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props;
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<select value={value} onChange={e => onChange(e.target.value)} className={inputClasses}>
{options.map(option => (
<option key={option.value} value={option.value}>
{option.label}
</option>
))}
</select>
</div>
);
};
export interface CheckboxProps {
label?: string;
checked: boolean;
onChange: (checked: boolean) => void;
className?: string;
required?: boolean;
disabled?: boolean;
}
export const Checkbox = (props: CheckboxProps) => {
const { label, checked, onChange, required, disabled } = props;
const id = React.useId();
return (
<div className={classNames('flex items-center gap-3', props.className)}>
<button
type="button"
role="switch"
id={id}
aria-checked={checked}
aria-required={required}
disabled={disabled}
onClick={() => !disabled && onChange(!checked)}
className={classNames(
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
checked ? 'bg-blue-600' : 'bg-gray-700',
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
)}
>
<span className="sr-only">Toggle {label}</span>
<span
className={classNames(
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
checked ? 'translate-x-5' : 'translate-x-0'
)}
/>
</button>
{label && (
<label
htmlFor={id}
className={classNames(
'text-sm font-medium cursor-pointer select-none',
disabled ? 'text-gray-500' : 'text-gray-300'
)}
>
{label}
</label>
)}
</div>
);
};
interface FormGroupProps {
label?: string;
className?: string;
children: React.ReactNode;
}
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children }) => {
return (
<div className={classNames(className)}>
{label && <label className={labelClasses}>{label}</label>}
<div className="px-4 space-y-2">{children}</div>
</div>
);
};

View File

@@ -0,0 +1,27 @@
import classNames from 'classnames';
interface Props {
className?: string;
children?: React.ReactNode;
}
export const TopBar: React.FC<Props> = ({ children, className }) => {
return (
<div
className={classNames(
'absolute top-0 left-0 w-full h-12 dark:bg-gray-900 shadow-sm z-10 flex items-center px-2',
className,
)}
>
{children ? children : null}
</div>
);
};
export const MainContent: React.FC<Props> = ({ children, className }) => {
return (
<div className={classNames('pt-14 px-4 absolute top-0 left-0 w-full h-full overflow-auto', className)}>
{children ? children : null}
</div>
);
};

View File

@@ -0,0 +1,30 @@
'use client';
import { useEffect, useState } from 'react';
export default function useDatasetList() {
const [datasets, setDatasets] = useState<string[]>([]);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshDatasets = () => {
setStatus('loading');
fetch('/api/datasets/list')
.then(res => res.json())
.then(data => {
console.log('Datasets:', data);
// sort
data.sort((a: string, b: string) => a.localeCompare(b));
setDatasets(data);
setStatus('success');
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshDatasets();
}, []);
return { datasets, setDatasets, status, refreshDatasets };
}

View File

@@ -0,0 +1,52 @@
'use client';
import { useEffect, useState, useRef } from 'react';
interface FileObject {
path: string;
size: number;
}
export default function useFilesList(jobID: string, reloadInterval: null | number = null) {
const [files, setFiles] = useState<FileObject[]>([]);
const didInitialLoadRef = useRef(false);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle');
const refreshFiles = () => {
let loadStatus: 'loading' | 'refreshing' = 'loading';
if (didInitialLoadRef.current) {
loadStatus = 'refreshing';
}
setStatus(loadStatus);
fetch(`/api/jobs/${jobID}/files`)
.then(res => res.json())
.then(data => {
console.log('Fetched files:', data);
if (data.files) {
setFiles(data.files);
}
setStatus('success');
didInitialLoadRef.current = true;
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshFiles();
if (reloadInterval) {
const interval = setInterval(() => {
refreshFiles();
}, reloadInterval);
return () => {
clearInterval(interval);
};
}
}, [jobID]);
return { files, setFiles, status, refreshFiles };
}

View File

@@ -0,0 +1,54 @@
'use client';
import { GPUApiResponse, GpuInfo } from '@/types';
import { useEffect, useState } from 'react';
export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
const [gpuList, setGpuList] = useState<GpuInfo[]>([]);
const [isGPUInfoLoaded, setIsLoaded] = useState(false);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const fetchGpuInfo = async () => {
setStatus('loading');
try {
const response = await fetch('/api/gpu');
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data: GPUApiResponse = await response.json();
let gpus = data.gpus.sort((a, b) => a.index - b.index);
if (gpuIds) {
gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
}
setGpuList(gpus);
setStatus('success');
} catch (err) {
console.error(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
setStatus('error');
} finally {
setIsLoaded(true);
}
};
useEffect(() => {
// Fetch immediately on component mount
fetchGpuInfo();
// Set up interval if specified
if (reloadInterval) {
const interval = setInterval(() => {
fetchGpuInfo();
}, reloadInterval);
// Cleanup interval on unmount
return () => {
clearInterval(interval);
};
}
}, [gpuIds, reloadInterval]); // Added dependencies
return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo };
}

40
ui/src/hooks/useJob.tsx Normal file
View File

@@ -0,0 +1,40 @@
'use client';
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
export default function useJob(jobID: string, reloadInterval: null | number = null) {
const [job, setJob] = useState<Job | null>(null);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshJob = () => {
setStatus('loading');
fetch(`/api/jobs?id=${jobID}`)
.then(res => res.json())
.then(data => {
console.log('Job:', data);
setJob(data);
setStatus('success');
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshJob();
if (reloadInterval) {
const interval = setInterval(() => {
refreshJob();
}, reloadInterval);
return () => {
clearInterval(interval);
}
}
}, [jobID]);
return { job, setJob, status, refreshJob };
}

View File

@@ -0,0 +1,37 @@
'use client';
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
export default function useJobsList(onlyActive = false) {
const [jobs, setJobs] = useState<Job[]>([]);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshJobs = () => {
setStatus('loading');
fetch('/api/jobs')
.then(res => res.json())
.then(data => {
console.log('Jobs:', data);
if (data.error) {
console.log('Error fetching jobs:', data.error);
setStatus('error');
} else {
if (onlyActive) {
data.jobs = data.jobs.filter((job: Job) => job.status === 'running');
}
setJobs(data.jobs);
setStatus('success');
}
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshJobs();
}, []);
return { jobs, setJobs, status, refreshJobs };
}

View File

@@ -0,0 +1,40 @@
'use client';
import { useEffect, useState } from 'react';
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
const [sampleImages, setSampleImages] = useState<string[]>([]);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshSampleImages = () => {
setStatus('loading');
fetch(`/api/jobs/${jobID}/samples`)
.then(res => res.json())
.then(data => {
if (data.samples) {
setSampleImages(data.samples);
}
setStatus('success');
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshSampleImages();
if (reloadInterval) {
const interval = setInterval(() => {
refreshSampleImages();
}, reloadInterval);
return () => {
clearInterval(interval);
};
}
}, [jobID]);
return { sampleImages, setSampleImages, status, refreshSampleImages };
}

View File

@@ -0,0 +1,28 @@
'use client';
import { useEffect, useState } from 'react';
export default function useSettings() {
const [settings, setSettings] = useState({
HF_TOKEN: '',
TRAINING_FOLDER: '',
DATASETS_FOLDER: '',
});
const [isSettingsLoaded, setIsLoaded] = useState(false);
useEffect(() => {
// Fetch current settings
fetch('/api/settings')
.then(res => res.json())
.then(data => {
setSettings({
HF_TOKEN: data.HF_TOKEN || '',
TRAINING_FOLDER: data.TRAINING_FOLDER || '',
DATASETS_FOLDER: data.DATASETS_FOLDER || '',
});
setIsLoaded(true);
})
.catch(error => console.error('Error fetching settings:', error));
}, []);
return { settings, setSettings, isSettingsLoaded };
}

4
ui/src/paths.ts Normal file
View File

@@ -0,0 +1,4 @@
import path from 'path';
export const TOOLKIT_ROOT = path.resolve('@', '..', '..');
export const defaultTrainFolder = path.join(TOOLKIT_ROOT, 'output');
export const defaultDatasetsFolder = path.join(TOOLKIT_ROOT, 'datasets');

68
ui/src/server/settings.ts Normal file
View File

@@ -0,0 +1,68 @@
import { PrismaClient } from '@prisma/client';
import { defaultDatasetsFolder } from '@/paths';
import { defaultTrainFolder } from '@/paths';
import NodeCache from 'node-cache';
const myCache = new NodeCache();
const prisma = new PrismaClient();
export const flushCache = () => {
myCache.flushAll();
};
export const getDatasetsRoot = async () => {
const key = 'DATASETS_FOLDER';
let datasetsPath = myCache.get(key) as string;
if (datasetsPath) {
return datasetsPath;
}
let row = await prisma.settings.findFirst({
where: {
key: 'DATASETS_FOLDER',
},
});
datasetsPath = defaultDatasetsFolder;
if (row?.value && row.value !== '') {
datasetsPath = row.value;
}
myCache.set(key, datasetsPath);
return datasetsPath as string;
};
export const getTrainingFolder = async () => {
const key = 'TRAINING_FOLDER';
let trainingRoot = myCache.get(key) as string;
if (trainingRoot) {
return trainingRoot;
}
let row = await prisma.settings.findFirst({
where: {
key: key,
},
});
trainingRoot = defaultTrainFolder;
if (row?.value && row.value !== '') {
trainingRoot = row.value;
}
myCache.set(key, trainingRoot);
return trainingRoot as string;
};
export const getHFToken = async () => {
const key = 'HF_TOKEN';
let token = myCache.get(key) as string;
if (token) {
return token;
}
let row = await prisma.settings.findFirst({
where: {
key: key,
},
});
token = '';
if (row?.value && row.value !== '') {
token = row.value;
}
myCache.set(key, token);
return token;
};

160
ui/src/types.ts Normal file
View File

@@ -0,0 +1,160 @@
/**
* GPU API response
*/
export interface GpuUtilization {
gpu: number;
memory: number;
}
export interface GpuMemory {
total: number;
free: number;
used: number;
}
export interface GpuPower {
draw: number;
limit: number;
}
export interface GpuClocks {
graphics: number;
memory: number;
}
export interface GpuFan {
speed: number;
}
export interface GpuInfo {
index: number;
name: string;
driverVersion: string;
temperature: number;
utilization: GpuUtilization;
memory: GpuMemory;
power: GpuPower;
clocks: GpuClocks;
fan: GpuFan;
}
export interface GPUApiResponse {
hasNvidiaSmi: boolean;
gpus: GpuInfo[];
error?: string;
}
/**
* Training configuration
*/
export interface NetworkConfig {
type: 'lora';
linear: number;
linear_alpha: number;
}
export interface SaveConfig {
dtype: string;
save_every: number;
max_step_saves_to_keep: number;
save_format: string;
push_to_hub: boolean;
}
export interface DatasetConfig {
folder_path: string;
mask_path: string | null;
mask_min_value: number;
default_caption: string;
caption_ext: string;
caption_dropout_rate: number;
shuffle_tokens?: boolean;
is_reg: boolean;
network_weight: number;
cache_latents_to_disk?: boolean;
resolution: number[];
}
export interface EMAConfig {
use_ema: boolean;
ema_decay: number;
}
export interface TrainConfig {
batch_size: number;
bypass_guidance_embedding?: boolean;
steps: number;
gradient_accumulation: number;
train_unet: boolean;
train_text_encoder: boolean;
gradient_checkpointing: boolean;
noise_scheduler: string;
timestep_type: string;
content_or_style: string;
optimizer: string;
lr: number;
ema_config?: EMAConfig;
dtype: string;
optimizer_params: {
weight_decay: number;
};
}
export interface QuantizeKwargsConfig {
exclude: string[];
}
export interface ModelConfig {
name_or_path: string;
is_flux?: boolean;
is_lumina2?: boolean;
quantize: boolean;
quantize_te: boolean;
quantize_kwargs?: QuantizeKwargsConfig;
}
export interface SampleConfig {
sampler: string;
sample_every: number;
width: number;
height: number;
prompts: string[];
neg: string;
seed: number;
walk_seed: boolean;
guidance_scale: number;
sample_steps: number;
}
export interface ProcessConfig {
type: 'ui_trainer';
sqlite_db_path?: string;
training_folder: string;
performance_log_every: number;
trigger_word: string | null;
device: string;
network?: NetworkConfig;
save: SaveConfig;
datasets: DatasetConfig[];
train: TrainConfig;
model: ModelConfig;
sample: SampleConfig;
}
export interface ConfigObject {
name: string;
process: ProcessConfig[];
}
export interface MetaConfig {
name: string;
version: string;
}
export interface JobConfig {
job: string;
config: ConfigObject;
meta: MetaConfig;
}

4
ui/src/utils/basic.ts Normal file
View File

@@ -0,0 +1,4 @@
export const objectCopy = <T>(obj: T): T => {
return JSON.parse(JSON.stringify(obj)) as T;
};

88
ui/src/utils/hooks.tsx Normal file
View File

@@ -0,0 +1,88 @@
import React from 'react';
/**
* Updates a deeply nested value in an object using a string path
* @param obj The object to update
* @param value The new value to set
* @param path String path to the property (e.g. 'config.process[0].model.name_or_path')
* @returns A new object with the updated value
*/
export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
// Create a copy of the original object to maintain immutability
const result = { ...obj };
// if path is not provided, be root path
if (!path) {
path = '';
}
// Split the path into segments
const pathArray = path.split('.').flatMap(segment => {
// Handle array notation like 'process[0]'
const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/);
if (arrayMatch) {
const propName = arrayMatch[1];
const indices = segment
.substring(propName.length)
.match(/\[(\d+)\]/g)
?.map(idx => parseInt(idx.substring(1, idx.length - 1)));
// Return property name followed by array indices
return [propName, ...(indices || [])];
}
return segment;
});
// Navigate to the target location
let current: any = result;
for (let i = 0; i < pathArray.length - 1; i++) {
const key = pathArray[i];
// If current key is a number, treat it as an array index
if (typeof key === 'number') {
if (!Array.isArray(current)) {
throw new Error(`Cannot access index ${key} of non-array`);
}
// Create a copy of the array to maintain immutability
current = [...current];
} else {
// For object properties, create a new object if it doesn't exist
if (current[key] === undefined) {
// Check if the next key is a number, if so create an array, otherwise an object
const nextKey = pathArray[i + 1];
current[key] = typeof nextKey === 'number' ? [] : {};
} else {
// Create a shallow copy to maintain immutability
current[key] = Array.isArray(current[key]) ? [...current[key]] : { ...current[key] };
}
}
// Move to the next level
current = current[key];
}
// Set the value at the final path segment
const finalKey = pathArray[pathArray.length - 1];
current[finalKey] = value;
return result;
}
/**
* Custom hook for managing a complex state object with string path updates
* @param initialState The initial state object
* @returns [state, setValue] tuple
*/
export function useNestedState<T>(initialState: T): [T, (value: any, path?: string) => void] {
const [state, setState] = React.useState<T>(initialState);
const setValue = React.useCallback((value: any, path?: string) => {
if (path === undefined) {
setState(value);
return
}
setState(prevState => setNestedValue(prevState, value, path));
}, []);
return [state, setValue];
}

75
ui/src/utils/jobs.ts Normal file
View File

@@ -0,0 +1,75 @@
import { JobConfig } from '@/types';
import { Job } from '@prisma/client';
export const startJob = (jobID: string) => {
return new Promise<void>((resolve, reject) => {
fetch(`/api/jobs/${jobID}/start`)
.then(res => res.json())
.then(data => {
console.log('Job started:', data);
resolve();
})
.catch(error => {
console.error('Error starting job:', error);
reject(error);
});
});
};
export const stopJob = (jobID: string) => {
return new Promise<void>((resolve, reject) => {
fetch(`/api/jobs/${jobID}/stop`)
.then(res => res.json())
.then(data => {
console.log('Job stopped:', data);
resolve();
})
.catch(error => {
console.error('Error stopping job:', error);
reject(error);
});
});
};
export const deleteJob = (jobID: string) => {
return new Promise<void>((resolve, reject) => {
fetch(`/api/jobs/${jobID}/delete`)
.then(res => res.json())
.then(data => {
console.log('Job deleted:', data);
resolve();
})
.catch(error => {
console.error('Error deleting job:', error);
reject(error);
});
});
};
export const getJobConfig = (job: Job) => {
return JSON.parse(job.job_config) as JobConfig;
};
export const getAvaliableJobActions = (job: Job) => {
const jobConfig = getJobConfig(job);
const isStopping = job.stop && job.status === 'running';
const canDelete = ['completed', 'stopped', 'error'].includes(job.status) && !isStopping;
const canEdit = ['completed', 'stopped', 'error'].includes(job.status) && !isStopping;
const canStop = job.status === 'running' && !isStopping;
let canStart = ['stopped', 'error'].includes(job.status) && !isStopping;
// can resume if more steps were added
if (job.status === 'completed' && jobConfig.config.process[0].train.steps > job.step && !isStopping) {
canStart = true;
}
return { canDelete, canEdit, canStop, canStart };
};
export const getNumberOfSamples = (job: Job) => {
const jobConfig = getJobConfig(job);
return jobConfig.config.process[0].sample?.prompts?.length || 0;
}
export const getTotalSteps = (job: Job) => {
const jobConfig = getJobConfig(job);
return jobConfig.config.process[0].train.steps;
}

31
ui/tailwind.config.ts Normal file
View File

@@ -0,0 +1,31 @@
import type { Config } from "tailwindcss";
const config: Config = {
content: [
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
],
darkMode: "class",
theme: {
extend: {
colors: {
gray: {
950: "#0a0a0a",
900: "#171717",
800: "#262626",
700: "#404040",
600: "#525252",
500: "#737373",
400: "#a3a3a3",
300: "#d4d4d4",
200: "#e5e5e5",
100: "#f5f5f5",
},
},
},
},
plugins: [],
};
export default config;

27
ui/tsconfig.json Normal file
View File

@@ -0,0 +1,27 @@
{
"compilerOptions": {
"target": "ES2017",
"lib": ["dom", "dom.iterable", "esnext"],
"allowJs": true,
"skipLibCheck": true,
"strict": true,
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"incremental": true,
"plugins": [
{
"name": "next"
}
],
"paths": {
"@/*": ["./src/*"]
}
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
"exclude": ["node_modules"]
}