4 .gitignore (vendored)
@@ -161,6 +161,7 @@ cython_debug/

/env.sh
/models
/datasets
/custom/*
!/custom/.gitkeep
/.tmp
@@ -177,4 +178,5 @@ cython_debug/

/wandb
.vscode/settings.json
.DS_Store
._.DS_Store
._.DS_Store
aitk_db.db
223 README.md
@@ -7,7 +7,7 @@
</a>

I am transitioning to working on my open source AI projects full time. If you find my work useful, please consider supporting me on [Patreon](https://www.patreon.com/ostris). I will be able to work on more projects and provide better support with your help.
I work on open source full time, which means I 100% rely on donations to make a living. If you find this project helpful, or use it for commercial purposes, please consider donating to support my work on [Patreon](https://www.patreon.com/ostris) or [GitHub Sponsors](https://github.com/sponsors/ostris).

## Installation

@@ -18,7 +18,6 @@ Requirements:
- git

Linux:
```bash
git clone https://github.com/ostris/ai-toolkit.git
@@ -43,6 +42,43 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
```

# AI Toolkit UI

<img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">

The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs, and to train models with a few clicks. It is still in early beta and will likely have bugs and frequent breaking changes. It is currently only tested on Linux.

WARNING: The UI is not secure and should not be exposed to the internet. It is only meant to be run locally or on a server that does not have ports exposed. Adding additional security is on the roadmap.
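
If you do need to reach the UI from another machine in the meantime, one generic workaround (not a toolkit feature, just standard SSH port forwarding; the user and hostname below are placeholders) is to tunnel the port instead of exposing it:

```bash
# forward the UI port to your local machine; the UI stays bound to localhost on the server
ssh -L 8675:localhost:8675 user@your-server
```

You can then browse `http://localhost:8675` locally while the tunnel is up.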

## Installing the UI

Requirements:
- Node.js > 18

You will need to do this with every update as well.

```bash
cd ui
npm install
npm run build
npm run update_db
```

## Running the UI

Make sure you built it as shown above. The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs.

```bash
cd ui
npm run start
```

You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.

## FLUX.1 Training

### Tutorial
@@ -275,186 +311,3 @@ You can also exclude layers by their names by using `ignore_if_contains` network

`ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
it will be ignored.

---

## EVERYTHING BELOW THIS LINE IS OUTDATED

It may still work like that, but I have not tested it in a while.

---

### Batch Image Generation

An image generator that can take prompts from a config file or from a txt file and generate them to a
folder. I mainly needed this for an SDXL test I am doing, but I added some polish to it so it can be used
for general batch image generation.
It all runs off a config file, which you can find an example of in `config/examples/generate.example.yaml`.
More info is in the comments in the example; see the invocation sketch below.
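
Assuming the usual workflow in this repo (copy an example config into `config/` and pass it to `run.py`), a generation run would look something like this; the copied filename is just a placeholder:

```bash
cp config/examples/generate.example.yaml config/generate.yaml
python3 run.py config/generate.yaml
```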

---

### LoRA (lierla), LoCON (LyCORIS) extractor

It is based on the extractor in the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) tool, but adds some QOL features
and LoRA (lierla) support. It can do multiple types of extractions in one run.
It all runs off a config file, which you can find an example of in `config/examples/extract.example.yml`.
Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
Then you can edit the file to your liking and call it like so:

```bash
python3 run.py config/whatever_you_want.yml
```

You can also put a full path to a config file, if you want to keep it somewhere else.

```bash
python3 run.py "/home/user/whatever_you_want.yml"
```

More notes on how it works are available in the example config file itself. LoRA and LoCON both support
extractions of 'fixed', 'threshold', 'ratio', 'quantile'. I'll update what these do and mean later.
Most people used fixed, which is traditional fixed-dimension extraction.

`process` is an array of different processes to run. You can add a few and mix and match. One LoRA, one LoCON, etc.

---

### LoRA Rescale

Change `<lora:my_lora:4.6>` to `<lora:my_lora:1.0>` or whatever you want with the same effect.
A tool for rescaling a LoRA's weights. It should work with LoCON as well, but I have not tested it.
It all runs off a config file, which you can find an example of in `config/examples/mod_lora_scale.yml`.
Just copy that file into the `config` folder and rename it to `whatever_you_want.yml`.
Then you can edit the file to your liking and call it like so:

```bash
python3 run.py config/whatever_you_want.yml
```

You can also put a full path to a config file, if you want to keep it somewhere else.

```bash
python3 run.py "/home/user/whatever_you_want.yml"
```

More notes on how it works are available in the example config file itself. This is useful when making
all LoRAs, as the ideal weight is rarely 1.0, but now you can fix that. For sliders, they can have weird scales from -2 to 2
or even -15 to 15. This will allow you to dial it in so they all have your desired scale.

---

### LoRA Slider Trainer

<a target="_blank" href="https://colab.research.google.com/github/ostris/ai-toolkit/blob/main/notebooks/SliderTraining.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This is how I train most of the recent sliders I have on Civitai; you can check them out in my [Civitai profile](https://civitai.com/user/Ostris/models).
It is based off the work by [p1atdev/LECO](https://github.com/p1atdev/LECO) and [rohitgandikota/erasing](https://github.com/rohitgandikota/erasing),
but has been heavily modified to create sliders rather than erasing concepts. I have a lot more plans for this, but it is
very functional as is. It is also very easy to use. Just copy the example config file in `config/examples/train_slider.example.yml`
to the `config` folder and rename it to `whatever_you_want.yml`. Then you can edit the file to your liking and call it like so:

```bash
python3 run.py config/whatever_you_want.yml
```

There is a lot more information in that example file. You can even run the example as is without any modifications to see
how it works. It will create a slider that turns all animals into dogs (neg) or cats (pos). Just run it like so:

```bash
python3 run.py config/examples/train_slider.example.yml
```

And you will be able to see how it works without configuring anything. No datasets are required for this method.
I will post a better tutorial soon.

---

## Extensions!!

You can now make and share custom extensions that run within this framework and have all the inbuilt tools
available to them. I will probably use this as the primary development method going
forward so I don't keep adding more and more features to this base repo. I will likely migrate a lot
of the existing functionality as well to make everything modular. There is an example extension in the `extensions`
folder that shows how to make a model merger extension. All of the code is heavily documented, which is hopefully
enough to get you started. To make an extension, just copy that example and replace all the things you need to.

### Model Merger - Example Extension

It is located in the `extensions` folder. It is a fully functional model merger that can merge as many models together
as you want. It is a good example of how to make an extension, but it is also a pretty useful feature, since most
mergers can only do one model at a time and this one will take as many as you want to feed it. There is an
example config file in there; just copy that to your `config` folder, rename it to `whatever_you_want.yml`,
and use it like any other config file (see the invocation sketch below).
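
As with the other tools in this repo, the merger config is passed straight to `run.py`; the config filename here is just a placeholder:

```bash
python3 run.py config/whatever_you_want.yml
```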

## WIP Tools

### VAE (Variational Auto Encoder) Trainer

This works, but is not ready for others to use and therefore does not have an example config.
I am still working on it. I will update this when it is ready.
I am adding a lot of features for criteria that I have used in my image enlargement work: a critic (discriminator),
content loss, style loss, and a few more. If you don't know, the VAEs
for Stable Diffusion (yes, even the MSE one, and SDXL) are horrible at smaller faces, and it holds SD back. I will fix this.
I'll post more about this with better examples later, but here is a quick test of a run through various VAEs.
Just went in and out. It is much worse on smaller faces than shown here.

<img src="https://raw.githubusercontent.com/ostris/ai-toolkit/main/assets/VAE_test1.jpg" width="768" height="auto">

---

## TODO
- [X] Add proper regs on sliders
- [X] Add SDXL support (base model only for now)
- [ ] Add plain erasing
- [ ] Make Textual inversion network trainer (network that spits out TI embeddings)

---

## Change Log

#### 2023-08-05
- Huge memory rework and slider rework. Slider training is better than ever with no more
RAM spikes. I also made it so all 4 parts of the slider algorithm run in one batch so they share gradient
accumulation. This makes it much faster and more stable.
- Updated the example config to be something more practical and more up to date with current methods. It is now
a detail slider and shows how to train one without a subject. 512x512 slider training for 1.5 should work on a
6GB GPU now. Will test soon to verify.

#### 2021-10-20
- Windows support bug fixes
- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
Check out the example in the `extensions` folder. Read more about that above.
- Model merging, provided via the example extension.

#### 2023-08-03
Another big refactor to make SD more modular.

Made batch image generation script.

#### 2023-08-01
Major changes and update. New LoRA rescale tool, look above for details. Added better metadata so
Automatic1111 knows what the base model is. Added some experiments and a ton of updates. This thing is still unstable
at the moment, so hopefully there are no breaking changes.

Unfortunately, I am too lazy to write a proper changelog with all the changes.

I added SDXL training to sliders... but... it does not work properly.
The slider training relies on a model's ability to understand that an unconditional (negative prompt)
means you do not want that concept in the output. SDXL does not understand this for whatever reason,
which makes separating out concepts within the model hard. I am sure the community will find a way to fix this
over time, but for now, it is not going to work properly. And if any of you are thinking "Could we maybe fix it
by adding 1 or 2 more text encoders to the model as well as a few more entirely separate diffusion networks?"
No. God no. It just needs a little training without every experimental new paper added to it. The KISS principle.

#### 2023-07-30
Added "anchors" to the slider trainer. This allows you to set a prompt that will be used as a
regularizer. You can set the network multiplier to force spread consistency at high weights.
227 extensions_built_in/sd_trainer/UITrainer.py (new file)
@@ -0,0 +1,227 @@
from collections import OrderedDict
import os
import sqlite3
import asyncio
import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional


AITK_Status = Literal["running", "stopped", "error", "completed"]


class UITrainer(SDTrainer):
    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super(UITrainer, self).__init__(process_id, job, config, **kwargs)
        self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
        print(f"Using SQLite database at {self.sqlite_db_path}")
        self.job_id = os.environ.get("AITK_JOB_ID", None)
        if self.job_id is None:
            raise Exception("AITK_JOB_ID not set")
        self.is_stopping = False
        # Create a thread pool for database operations
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Track all async tasks
        self._async_tasks = []
        # Initialize the status
        self._run_async_operation(self._update_status("running", "Starting"))

    def _run_async_operation(self, coro):
        """Helper method to run an async coroutine and track the task."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop exists, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Create a task and track it
        if loop.is_running():
            task = asyncio.run_coroutine_threadsafe(coro, loop)
            self._async_tasks.append(asyncio.wrap_future(task))
        else:
            task = loop.create_task(coro)
            self._async_tasks.append(task)
            loop.run_until_complete(task)

    async def _execute_db_operation(self, operation_func):
        """Execute a database operation in a separate thread to avoid blocking."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.thread_pool, operation_func)

    def _db_connect(self):
        """Create a new connection for each operation to avoid locking."""
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn

    def should_stop(self):
        def _check_stop():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT stop FROM Job WHERE id = ?", (self.job_id,))
                stop = cursor.fetchone()
                return False if stop is None else stop[0] == 1

        return _check_stop()

    def maybe_stop(self):
        if self.should_stop():
            self._run_async_operation(
                self._update_status("stopped", "Job stopped"))
            self.is_stopping = True
            raise Exception("Job stopped")

    async def _update_key(self, key, value):
        if not self.accelerator.is_main_process:
            return

        def _do_update():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN IMMEDIATE")
                try:
                    # Convert the value to string if it's not already
                    if isinstance(value, str):
                        value_to_insert = value
                    else:
                        value_to_insert = str(value)

                    # The value is parameterized; the column name is interpolated,
                    # so callers must only pass known column names
                    update_query = f"UPDATE Job SET {key} = ? WHERE id = ?"
                    cursor.execute(
                        update_query, (value_to_insert, self.job_id))
                finally:
                    cursor.execute("COMMIT")

        await self._execute_db_operation(_do_update)

    def update_step(self):
        """Non-blocking update of the step count."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key("step", self.step_num))

    def update_db_key(self, key, value):
        """Non-blocking update of a key in the database."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key(key, value))

    async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
        if not self.accelerator.is_main_process:
            return

        def _do_update():
            with self._db_connect() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN IMMEDIATE")
                try:
                    if info is not None:
                        cursor.execute(
                            "UPDATE Job SET status = ?, info = ? WHERE id = ?",
                            (status, info, self.job_id)
                        )
                    else:
                        cursor.execute(
                            "UPDATE Job SET status = ? WHERE id = ?",
                            (status, self.job_id)
                        )
                finally:
                    cursor.execute("COMMIT")

        await self._execute_db_operation(_do_update)

    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Non-blocking update of status."""
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_status(status, info))

    async def wait_for_all_async(self):
        """Wait for all tracked async operations to complete."""
        if not self._async_tasks:
            return

        try:
            await asyncio.gather(*self._async_tasks)
        finally:
            # Clear the task list after completion
            self._async_tasks.clear()

    def on_error(self, e: Exception):
        super(UITrainer, self).on_error(e)
        if self.accelerator.is_main_process and not self.is_stopping:
            self.update_status("error", str(e))
            self.update_db_key("step", self.last_save_step)
        asyncio.run(self.wait_for_all_async())
        self.thread_pool.shutdown(wait=True)

    def handle_timing_print_hook(self, timing_dict):
        if "train_loop" not in timing_dict:
            print("train_loop not found in timing_dict", timing_dict)
            return
        seconds_per_iter = timing_dict["train_loop"]
        # determine iter/sec or sec/iter
        if seconds_per_iter < 1:
            iters_per_sec = 1 / seconds_per_iter
            self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
        else:
            self.update_db_key(
                "speed_string", f"{seconds_per_iter:.2f} sec/iter")

    def done_hook(self):
        super(UITrainer, self).done_hook()
        self.update_status("completed", "Training completed")
        # Wait for all async operations to finish before shutting down
        asyncio.run(self.wait_for_all_async())
        self.thread_pool.shutdown(wait=True)

    def end_step_hook(self):
        super(UITrainer, self).end_step_hook()
        self.update_step()
        self.maybe_stop()

    def hook_before_model_load(self):
        super().hook_before_model_load()
        self.maybe_stop()
        self.update_status("running", "Loading model")

    def before_dataset_load(self):
        super().before_dataset_load()
        self.maybe_stop()
        self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        self.maybe_stop()
        self.update_step()
        self.update_status("running", "Training")
        self.timer.add_after_print_hook(self.handle_timing_print_hook)

    def status_update_hook_func(self, string):
        self.update_status("running", string)

    def hook_after_sd_init_before_load(self):
        super().hook_after_sd_init_before_load()
        self.maybe_stop()
        self.sd.add_status_update_hook(self.status_update_hook_func)

    def sample_step_hook(self, img_num, total_imgs):
        super().sample_step_hook(img_num, total_imgs)
        self.maybe_stop()
        self.update_status(
            "running", f"Generating images - {img_num + 1}/{total_imgs}")

    def sample(self, step=None, is_first=False):
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 0/{total_imgs}")
        super().sample(step, is_first)
        self.maybe_stop()
        self.update_status("running", "Training")

    def save(self, step=None):
        self.maybe_stop()
        self.update_status("running", "Saving model")
        super().save(step)
        self.maybe_stop()
        self.update_status("running", "Training")
@@ -18,6 +18,22 @@ class SDTrainerExtension(Extension):
        from .SDTrainer import SDTrainer
        return SDTrainer


# This is for generic training (LoRA, Dreambooth, FineTuning)
class UITrainerExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "ui_trainer"

    # name is the name of the extension for printing
    name = "UI Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .UITrainer import UITrainer
        return UITrainer


# for backwards compatibility
class TextualInversionTrainer(SDTrainerExtension):
@@ -26,5 +42,5 @@ class TextualInversionTrainer(SDTrainerExtension):

AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    SDTrainerExtension, TextualInversionTrainer
    SDTrainerExtension, TextualInversionTrainer, UITrainerExtension
]
@@ -24,6 +24,9 @@ class BaseProcess(object):
|
||||
self.performance_log_every = self.get_conf('performance_log_every', 0)
|
||||
|
||||
print(json.dumps(self.config, indent=4))
|
||||
|
||||
def on_error(self, e: Exception):
|
||||
pass
|
||||
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
|
||||
# split key by '.' and recursively get the value
|
||||
|
||||
@@ -92,6 +92,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
self.epoch_num = 0
|
||||
self.last_save_step = 0
|
||||
# start at 1 so we can do a sample at the start
|
||||
self.grad_accumulation_step = 1
|
||||
# if true, then we do not do an optimizer step. We are accumulating gradients
|
||||
@@ -439,6 +440,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def post_save_hook(self, save_path):
|
||||
# override in subclass
|
||||
pass
|
||||
|
||||
def done_hook(self):
|
||||
pass
|
||||
|
||||
def end_step_hook(self):
|
||||
pass
|
||||
|
||||
def save(self, step=None):
|
||||
if not self.accelerator.is_main_process:
|
||||
@@ -453,6 +460,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
step_num = ''
|
||||
if step is not None:
|
||||
self.last_save_step = step
|
||||
# zeropad 9 digits
|
||||
step_num = f"_{str(step).zfill(9)}"
|
||||
|
||||
@@ -648,6 +656,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.logger.start()
|
||||
self.prepare_accelerator()
|
||||
|
||||
def sample_step_hook(self, img_num, total_imgs):
|
||||
pass
|
||||
|
||||
def prepare_accelerator(self):
|
||||
# set some config
|
||||
@@ -722,6 +732,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def hook_train_loop(self, batch):
|
||||
# return loss
|
||||
return 0.0
|
||||
|
||||
def hook_after_sd_init_before_load(self):
|
||||
pass
|
||||
|
||||
def get_latest_save_path(self, name=None, post=''):
|
||||
if name == None:
|
||||
@@ -1417,8 +1430,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
custom_pipeline=self.custom_pipeline,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
|
||||
self.hook_after_sd_init_before_load()
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
self.sd.add_after_sample_image_hook(self.sample_step_hook)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
@@ -1812,6 +1829,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd)
|
||||
|
||||
flush()
|
||||
self.last_save_step = self.step_num
|
||||
### HOOK ###
|
||||
self.hook_before_train_loop()
|
||||
|
||||
@@ -2091,6 +2109,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# update various steps
|
||||
self.step_num = step + 1
|
||||
self.grad_accumulation_step += 1
|
||||
self.end_step_hook()
|
||||
|
||||
|
||||
###################################################################
|
||||
@@ -2110,13 +2129,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.logger.finish()
|
||||
self.accelerator.end_training()
|
||||
|
||||
if self.save_config.push_to_hub:
|
||||
if("HF_TOKEN" not in os.environ):
|
||||
interpreter_login(new_session=False, write_permission=True)
|
||||
self.push_to_hub(
|
||||
repo_id=self.save_config.hf_repo_id,
|
||||
private=self.save_config.hf_private
|
||||
)
|
||||
if self.accelerator.is_main_process:
|
||||
# push to hub
|
||||
if self.save_config.push_to_hub:
|
||||
if("HF_TOKEN" not in os.environ):
|
||||
interpreter_login(new_session=False, write_permission=True)
|
||||
self.push_to_hub(
|
||||
repo_id=self.save_config.hf_repo_id,
|
||||
private=self.save_config.hf_private
|
||||
)
|
||||
del (
|
||||
self.sd,
|
||||
unet,
|
||||
@@ -2128,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
flush()
|
||||
self.done_hook()
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
|
||||
@@ -32,4 +32,5 @@ sentencepiece
huggingface_hub
peft
gradio
python-slugify
python-slugify
sqlite3
4 run.py
@@ -88,6 +88,10 @@ def main():
        except Exception as e:
            print_acc(f"Error running job: {e}")
            jobs_failed += 1
            try:
                job.process[0].on_error(e)
            except Exception as e2:
                print_acc(f"Error running on_error: {e2}")
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise e
@@ -379,7 +379,8 @@ class TrainConfig:
|
||||
self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
|
||||
|
||||
ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
|
||||
if ema_config is not None:
|
||||
# if it is set explicitly to false, leave it false.
|
||||
if ema_config is not None and ema_config.get('use_ema', None) is not None:
|
||||
ema_config['use_ema'] = True
|
||||
print(f"Using EMA")
|
||||
else:
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import List, Optional, Dict, Type, Union
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
|
||||
from transformers import CLIPTextModel
|
||||
from toolkit.models.lokr import LokrModule
|
||||
|
||||
from .config_modules import NetworkConfig
|
||||
from .lorm import count_parameters
|
||||
|
||||
282 toolkit/models/lokr.py (new file)
@@ -0,0 +1,282 @@
# based heavily on https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from toolkit.network_mixins import ToolkitModuleMixin

from typing import TYPE_CHECKING, Union, List

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# 4, build custom backward function
# -


def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    Return a tuple of two values decomposing the input dimension by the number closest to factor;
    the second value is higher than or equal to the first value.

    In LoRA with Kronecker product, the first value is a value for weight scale,
    and the second value is a value for weight.

    Because of the non-commutative property, A⊗B ≠ B⊗A, so the meaning of the two matrices is slightly different.

    Examples (dimension -> result for each factor):
        factor:    -1          2           4           8           16
        127  ->  127, 1     127, 1      127, 1      127, 1      127, 1
        128  ->   16, 8      64, 2       32, 4       16, 8       16, 8
        250  ->  125, 2     125, 2      125, 2      125, 2      125, 2
        360  ->   45, 8     180, 2       90, 4       45, 8       45, 8
        512  ->   32, 16    256, 2      128, 4       64, 8       32, 16
        1024 ->   32, 32    512, 2      256, 4      128, 8       64, 16
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n

def make_weight_cp(t, wa, wb):
|
||||
rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
|
||||
return rebuild2
|
||||
|
||||
|
||||
def make_kron(w1, w2, scale):
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
rebuild = torch.kron(w1, w2)
|
||||
|
||||
return rebuild*scale
|
||||
|
||||
|
||||
class LokrModule(ToolkitModuleMixin, nn.Module):
|
||||
"""
|
||||
modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
|
||||
and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
|
||||
and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=0.,
|
||||
rank_dropout=0.,
|
||||
module_dropout=0.,
|
||||
use_cp=False,
|
||||
decompose_both = False,
|
||||
network: 'LoRASpecialNetwork' = None,
|
||||
factor:int=-1, # factorization factor
|
||||
**kwargs,
|
||||
):
|
||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||
ToolkitModuleMixin.__init__(self, network=network)
|
||||
torch.nn.Module.__init__(self)
|
||||
factor = int(factor)
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.cp = False
|
||||
self.use_w1 = False
|
||||
self.use_w2 = False
|
||||
|
||||
self.shape = org_module.weight.shape
|
||||
if org_module.__class__.__name__ == 'Conv2d':
|
||||
in_dim = org_module.in_channels
|
||||
k_size = org_module.kernel_size
|
||||
out_dim = org_module.out_channels
|
||||
|
||||
in_m, in_n = factorization(in_dim, factor)
|
||||
out_l, out_k = factorization(out_dim, factor)
|
||||
shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size)
|
||||
|
||||
self.cp = use_cp and k_size!=(1, 1)
|
||||
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
||||
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
|
||||
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
|
||||
else:
|
||||
self.use_w1 = True
|
||||
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
|
||||
|
||||
if lora_dim >= max(shape[0][1], shape[1][1])/2:
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size))
|
||||
elif self.cp:
|
||||
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode
|
||||
else: # Conv2d not cp
|
||||
# bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2]
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3]))
|
||||
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2)
|
||||
|
||||
self.op = F.conv2d
|
||||
self.extra_args = {
|
||||
"stride": org_module.stride,
|
||||
"padding": org_module.padding,
|
||||
"dilation": org_module.dilation,
|
||||
"groups": org_module.groups
|
||||
}
|
||||
|
||||
else: # Linear
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
in_m, in_n = factorization(in_dim, factor)
|
||||
out_l, out_k = factorization(out_dim, factor)
|
||||
shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
|
||||
|
||||
# smaller part. weight scale
|
||||
if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2:
|
||||
self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim))
|
||||
self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0]))
|
||||
else:
|
||||
self.use_w1 = True
|
||||
self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode
|
||||
|
||||
if lora_dim < max(shape[0][1], shape[1][1])/2:
|
||||
# bigger part. weight and LoRA. [b, dim] x [dim, d]
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
|
||||
# w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
|
||||
else:
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
|
||||
|
||||
self.op = F.linear
|
||||
self.extra_args = {}
|
||||
|
||||
self.dropout = dropout
|
||||
if dropout:
|
||||
print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
if self.use_w2 and self.use_w1:
|
||||
#use scale = 1
|
||||
alpha = lora_dim
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
if self.use_w2:
|
||||
torch.nn.init.constant_(self.lokr_w2, 0)
|
||||
else:
|
||||
if self.cp:
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
|
||||
torch.nn.init.constant_(self.lokr_w2_b, 0)
|
||||
|
||||
if self.use_w1:
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
|
||||
else:
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5))
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5))
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = [org_module]
|
||||
weight = make_kron(
|
||||
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
|
||||
(self.lokr_w2 if self.use_w2
|
||||
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
||||
else self.lokr_w2_a@self.lokr_w2_b),
|
||||
torch.tensor(self.multiplier * self.scale)
|
||||
)
|
||||
assert torch.sum(torch.isnan(weight)) == 0, "weight is nan"
|
||||
|
||||
# Same as locon.py
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
|
||||
def get_weight(self, orig_weight = None):
|
||||
weight = make_kron(
|
||||
self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
|
||||
(self.lokr_w2 if self.use_w2
|
||||
else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
|
||||
else self.lokr_w2_a@self.lokr_w2_b),
|
||||
torch.tensor(self.scale)
|
||||
)
|
||||
if orig_weight is not None:
|
||||
weight = weight.reshape(orig_weight.shape)
|
||||
if self.training and self.rank_dropout:
|
||||
drop = torch.rand(weight.size(0)) < self.rank_dropout
|
||||
weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
|
||||
return weight
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_max_norm(self, max_norm, device=None):
|
||||
orig_norm = self.get_weight().norm()
|
||||
norm = torch.clamp(orig_norm, max_norm/2)
|
||||
desired = torch.clamp(norm, max=max_norm)
|
||||
ratio = desired.cpu()/norm.cpu()
|
||||
|
||||
scaled = ratio.item() != 1.0
|
||||
if scaled:
|
||||
modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
|
||||
if self.use_w1:
|
||||
self.lokr_w1 *= ratio**(1/modules)
|
||||
else:
|
||||
self.lokr_w1_a *= ratio**(1/modules)
|
||||
self.lokr_w1_b *= ratio**(1/modules)
|
||||
|
||||
if self.use_w2:
|
||||
self.lokr_w2 *= ratio**(1/modules)
|
||||
else:
|
||||
if self.cp:
|
||||
self.lokr_t2 *= ratio**(1/modules)
|
||||
self.lokr_w2_a *= ratio**(1/modules)
|
||||
self.lokr_w2_b *= ratio**(1/modules)
|
||||
|
||||
return scaled, orig_norm*ratio
|
||||
|
||||
def forward(self, x):
|
||||
if self.module_dropout and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return self.op(
|
||||
x,
|
||||
self.org_module[0].weight.data,
|
||||
None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
||||
)
|
||||
weight = (
|
||||
self.org_module[0].weight.data
|
||||
+ self.get_weight(self.org_module[0].weight.data) * self.multiplier
|
||||
)
|
||||
bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
|
||||
return self.op(
|
||||
x,
|
||||
weight.view(self.shape),
|
||||
bias,
|
||||
**self.extra_args
|
||||
)
|
||||
@@ -202,6 +202,8 @@ class StableDiffusion:
|
||||
|
||||
# merge in and preview active with -1 weight
|
||||
self.invert_assistant_lora = False
|
||||
self._after_sample_img_hooks = []
|
||||
self._status_update_hooks = []
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
@@ -540,10 +542,10 @@ class StableDiffusion:
|
||||
tokenizer = pipe.tokenizer
|
||||
|
||||
elif self.model_config.is_flux:
|
||||
print_acc("Loading Flux model")
|
||||
self.print_and_status_update("Loading Flux model")
|
||||
# base_model_path = "black-forest-labs/FLUX.1-schnell"
|
||||
base_model_path = self.model_config.name_or_path_original
|
||||
print_acc("Loading transformer")
|
||||
self.print_and_status_update("Loading transformer")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
local_files_only = False
|
||||
@@ -688,7 +690,7 @@ class StableDiffusion:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
print_acc("Quantizing transformer")
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
@@ -698,7 +700,7 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print_acc("Loading vae")
|
||||
self.print_and_status_update("Loading VAE")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -707,7 +709,7 @@ class StableDiffusion:
|
||||
text_encoder_2 = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
|
||||
else:
|
||||
print_acc("Loading t5")
|
||||
self.print_and_status_update("Loading T5")
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
|
||||
torch_dtype=dtype)
|
||||
@@ -717,19 +719,19 @@ class StableDiffusion:
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
if self.is_flex2:
|
||||
print_acc("Quantizing LLM")
|
||||
self.print_and_status_update("Quantizing LLM")
|
||||
else:
|
||||
print_acc("Quantizing T5")
|
||||
self.print_and_status_update("Quantizing T5")
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
print_acc("Loading clip")
|
||||
self.print_and_status_update("Loading CLIP")
|
||||
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
print_acc("making pipe")
|
||||
self.print_and_status_update("Making pipe")
|
||||
Pipe = FluxPipeline
|
||||
if self.is_flex2:
|
||||
Pipe = Flex2Pipeline
|
||||
@@ -746,7 +748,7 @@ class StableDiffusion:
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.transformer = transformer
|
||||
|
||||
print_acc("preparing")
|
||||
self.print_and_status_update("Preparing Model")
|
||||
|
||||
text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
@@ -763,10 +765,10 @@ class StableDiffusion:
|
||||
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||
flush()
|
||||
elif self.model_config.is_lumina2:
|
||||
print_acc("Loading Lumina2 model")
|
||||
self.print_and_status_update("Loading Lumina2 model")
|
||||
# base_model_path = "black-forest-labs/FLUX.1-schnell"
|
||||
base_model_path = self.model_config.name_or_path_original
|
||||
print_acc("Loading transformer")
|
||||
self.print_and_status_update("Loading transformer")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
if os.path.exists(transformer_path):
|
||||
@@ -802,7 +804,7 @@ class StableDiffusion:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
print_acc("Quantizing transformer")
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
|
||||
freeze(transformer)
|
||||
transformer.to(self.device_torch)
|
||||
@@ -812,16 +814,16 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print_acc("Loading vae")
|
||||
self.print_and_status_update("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.te_name_or_path is not None:
|
||||
print_acc("Loading TE")
|
||||
self.print_and_status_update("Loading TE")
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
|
||||
text_encoder = AutoModel.from_pretrained(self.model_config.te_name_or_path, torch_dtype=dtype)
|
||||
else:
|
||||
print_acc("Loading Gemma2")
|
||||
self.print_and_status_update("Loading Gemma2")
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
|
||||
@@ -829,12 +831,12 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize_te:
|
||||
print_acc("Quantizing Gemma2")
|
||||
self.print_and_status_update("Quantizing Gemma2")
|
||||
quantize(text_encoder, weights=qfloat8)
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
print_acc("making pipe")
|
||||
self.print_and_status_update("Making pipe")
|
||||
pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline(
|
||||
scheduler=scheduler,
|
||||
text_encoder=None,
|
||||
@@ -845,7 +847,7 @@ class StableDiffusion:
|
||||
pipe.text_encoder = text_encoder
|
||||
pipe.transformer = transformer
|
||||
|
||||
print_acc("preparing")
|
||||
self.print_and_status_update("Preparing Model")
|
||||
|
||||
text_encoder = pipe.text_encoder
|
||||
tokenizer = pipe.tokenizer
|
||||
@@ -1032,6 +1034,25 @@ class StableDiffusion:
|
||||
self.refiner_unet = refiner.unet
|
||||
del refiner
|
||||
flush()
|
||||
|
||||
def _after_sample_image(self, img_num, total_imgs):
|
||||
# process all hooks
|
||||
for hook in self._after_sample_img_hooks:
|
||||
hook(img_num, total_imgs)
|
||||
|
||||
def add_after_sample_image_hook(self, func):
|
||||
self._after_sample_img_hooks.append(func)
|
||||
|
||||
def _status_update(self, status: str):
|
||||
for hook in self._status_update_hooks:
|
||||
hook(status)
|
||||
|
||||
def print_and_status_update(self, status: str):
|
||||
print_acc(status)
|
||||
self._status_update(status)
|
||||
|
||||
def add_status_update_hook(self, func):
|
||||
self._status_update_hooks.append(func)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_images(
|
||||
@@ -1598,6 +1619,7 @@ class StableDiffusion:
|
||||
|
||||
gen_config.save_image(img, i)
|
||||
gen_config.log_image(img, i)
|
||||
self._after_sample_image(i, len(image_configs))
|
||||
flush()
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
|
||||
|
||||
@@ -9,6 +9,7 @@ class Timer:
|
||||
self.timers = OrderedDict()
|
||||
self.active_timers = {}
|
||||
self.current_timer = None # Used for the context manager functionality
|
||||
self._after_print_hooks = []
|
||||
|
||||
def start(self, timer_name):
|
||||
if timer_name not in self.timers:
|
||||
@@ -34,12 +35,20 @@ class Timer:
|
||||
if len(self.timers[timer_name]) > self.max_buffer:
|
||||
self.timers[timer_name].popleft()
|
||||
|
||||
def add_after_print_hook(self, hook):
|
||||
self._after_print_hooks.append(hook)
|
||||
|
||||
def print(self):
|
||||
print(f"\nTimer '{self.name}':")
|
||||
timing_dict = {}
|
||||
# sort by longest at top
|
||||
for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
|
||||
avg_time = sum(timings) / len(timings)
|
||||
print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
|
||||
timing_dict[timer_name] = avg_time
|
||||
|
||||
for hook in self._after_print_hooks:
|
||||
hook(timing_dict)
|
||||
|
||||
print('')
|
||||
|
||||
|
||||
42
ui/.gitignore
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||
|
||||
# dependencies
|
||||
/node_modules
|
||||
/.pnp
|
||||
.pnp.*
|
||||
.yarn/*
|
||||
!.yarn/patches
|
||||
!.yarn/plugins
|
||||
!.yarn/releases
|
||||
!.yarn/versions
|
||||
|
||||
# testing
|
||||
/coverage
|
||||
|
||||
# next.js
|
||||
/.next/
|
||||
/out/
|
||||
|
||||
# production
|
||||
/build
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
*.pem
|
||||
|
||||
# debug
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.pnpm-debug.log*
|
||||
|
||||
# env files (can opt-in for committing if needed)
|
||||
.env*
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
# typescript
|
||||
*.tsbuildinfo
|
||||
next-env.d.ts
|
||||
aitk_db.db
|
||||
36
ui/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
|
||||
|
||||
## Getting Started
|
||||
|
||||
First, run the development server:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
# or
|
||||
yarn dev
|
||||
# or
|
||||
pnpm dev
|
||||
# or
|
||||
bun dev
|
||||
```
|
||||
|
||||
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
|
||||
|
||||
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
|
||||
|
||||
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
|
||||
|
||||
## Learn More
|
||||
|
||||
To learn more about Next.js, take a look at the following resources:
|
||||
|
||||
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
|
||||
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
|
||||
|
||||
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
|
||||
|
||||
## Deploy on Vercel
|
||||
|
||||
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
|
||||
|
||||
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
|
||||
15
ui/next.config.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import type { NextConfig } from 'next';
|
||||
|
||||
const nextConfig: NextConfig = {
|
||||
typescript: {
|
||||
// Remove this. Build fails because of route types
|
||||
ignoreBuildErrors: true,
|
||||
},
|
||||
experimental: {
|
||||
serverActions: {
|
||||
bodySizeLimit: '100mb',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export default nextConfig;
|
||||
4516
ui/package-lock.json
generated
Normal file
40
ui/package.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"name": "ai-toolkit-ui",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build",
|
||||
"start": "next start --port 8675",
|
||||
"lint": "next lint",
|
||||
"update_db": "npx prisma generate ; npx prisma db push",
|
||||
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
||||
},
|
||||
"dependencies": {
|
||||
"@headlessui/react": "^2.2.0",
|
||||
"@prisma/client": "^6.3.1",
|
||||
"axios": "^1.7.9",
|
||||
"classnames": "^2.5.1",
|
||||
"lucide-react": "^0.475.0",
|
||||
"next": "15.1.7",
|
||||
"node-cache": "^5.1.2",
|
||||
"prisma": "^6.3.1",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-dropzone": "^14.3.5",
|
||||
"react-global-hooks": "^1.3.5",
|
||||
"react-icons": "^5.5.0",
|
||||
"sqlite3": "^5.1.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"postcss": "^8",
|
||||
"prettier": "^3.5.1",
|
||||
"prettier-basic": "^1.0.0",
|
||||
"tailwindcss": "^3.4.1",
|
||||
"typescript": "^5"
|
||||
},
|
||||
"prettier": "prettier-basic"
|
||||
}
|
||||
8
ui/postcss.config.mjs
Normal file
@@ -0,0 +1,8 @@
|
||||
/** @type {import('postcss-load-config').Config} */
|
||||
const config = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
},
|
||||
};
|
||||
|
||||
export default config;
|
||||
28 ui/prisma/schema.prisma (new file)
@@ -0,0 +1,28 @@
generator client {
  provider = "prisma-client-js"
}

datasource db {
  provider = "sqlite"
  url      = "file:../../aitk_db.db"
}

model Settings {
  id    Int    @id @default(autoincrement())
  key   String @unique
  value String
}

model Job {
  id           String   @id @default(uuid())
  name         String   @unique
  gpu_ids      String
  job_config   String // JSON string
  created_at   DateTime @default(now())
  updated_at   DateTime @updatedAt
  status       String   @default("stopped")
  stop         Boolean  @default(false)
  step         Int      @default(0)
  info         String   @default("")
  speed_string String   @default("")
}
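
The `stop` column of this `Job` model is what `UITrainer.should_stop` polls between steps. Purely as an illustration (not an official workflow; the job name is a placeholder), flipping that flag by hand with the sqlite3 CLI would request a graceful stop:

```bash
sqlite3 aitk_db.db "UPDATE Job SET stop = 1 WHERE name = 'my_job';"
```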
1
ui/public/file.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg fill="none" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg"><path d="M14.5 13.5V5.41a1 1 0 0 0-.3-.7L9.8.29A1 1 0 0 0 9.08 0H1.5v13.5A2.5 2.5 0 0 0 4 16h8a2.5 2.5 0 0 0 2.5-2.5m-1.5 0v-7H8v-5H3v12a1 1 0 0 0 1 1h8a1 1 0 0 0 1-1M9.5 5V2.12L12.38 5zM5.13 5h-.62v1.25h2.12V5zm-.62 3h7.12v1.25H4.5zm.62 3h-.62v1.25h7.12V11z" clip-rule="evenodd" fill="#666" fill-rule="evenodd"/></svg>
|
||||
|
After Width: | Height: | Size: 391 B |
1
ui/public/globe.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><g clip-path="url(#a)"><path fill-rule="evenodd" clip-rule="evenodd" d="M10.27 14.1a6.5 6.5 0 0 0 3.67-3.45q-1.24.21-2.7.34-.31 1.83-.97 3.1M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16m.48-1.52a7 7 0 0 1-.96 0H7.5a4 4 0 0 1-.84-1.32q-.38-.89-.63-2.08a40 40 0 0 0 3.92 0q-.25 1.2-.63 2.08a4 4 0 0 1-.84 1.31zm2.94-4.76q1.66-.15 2.95-.43a7 7 0 0 0 0-2.58q-1.3-.27-2.95-.43a18 18 0 0 1 0 3.44m-1.27-3.54a17 17 0 0 1 0 3.64 39 39 0 0 1-4.3 0 17 17 0 0 1 0-3.64 39 39 0 0 1 4.3 0m1.1-1.17q1.45.13 2.69.34a6.5 6.5 0 0 0-3.67-3.44q.65 1.26.98 3.1M8.48 1.5l.01.02q.41.37.84 1.31.38.89.63 2.08a40 40 0 0 0-3.92 0q.25-1.2.63-2.08a4 4 0 0 1 .85-1.32 7 7 0 0 1 .96 0m-2.75.4a6.5 6.5 0 0 0-3.67 3.44 29 29 0 0 1 2.7-.34q.31-1.83.97-3.1M4.58 6.28q-1.66.16-2.95.43a7 7 0 0 0 0 2.58q1.3.27 2.95.43a18 18 0 0 1 0-3.44m.17 4.71q-1.45-.12-2.69-.34a6.5 6.5 0 0 0 3.67 3.44q-.65-1.27-.98-3.1" fill="#666"/></g><defs><clipPath id="a"><path fill="#fff" d="M0 0h16v16H0z"/></clipPath></defs></svg>
|
||||
|
After Width: | Height: | Size: 1.0 KiB |
1
ui/public/next.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 394 80"><path fill="#000" d="M262 0h68.5v12.7h-27.2v66.6h-13.6V12.7H262V0ZM149 0v12.7H94v20.4h44.3v12.6H94v21h55v12.6H80.5V0h68.7zm34.3 0h-17.8l63.8 79.4h17.9l-32-39.7 32-39.6h-17.9l-23 28.6-23-28.6zm18.3 56.7-9-11-27.1 33.7h17.8l18.3-22.7z"/><path fill="#000" d="M81 79.3 17 0H0v79.3h13.6V17l50.2 62.3H81Zm252.6-.4c-1 0-1.8-.4-2.5-1s-1.1-1.6-1.1-2.6.3-1.8 1-2.5 1.6-1 2.6-1 1.8.3 2.5 1a3.4 3.4 0 0 1 .6 4.3 3.7 3.7 0 0 1-3 1.8zm23.2-33.5h6v23.3c0 2.1-.4 4-1.3 5.5a9.1 9.1 0 0 1-3.8 3.5c-1.6.8-3.5 1.3-5.7 1.3-2 0-3.7-.4-5.3-1s-2.8-1.8-3.7-3.2c-.9-1.3-1.4-3-1.4-5h6c.1.8.3 1.6.7 2.2s1 1.2 1.6 1.5c.7.4 1.5.5 2.4.5 1 0 1.8-.2 2.4-.6a4 4 0 0 0 1.6-1.8c.3-.8.5-1.8.5-3V45.5zm30.9 9.1a4.4 4.4 0 0 0-2-3.3 7.5 7.5 0 0 0-4.3-1.1c-1.3 0-2.4.2-3.3.5-.9.4-1.6 1-2 1.6a3.5 3.5 0 0 0-.3 4c.3.5.7.9 1.3 1.2l1.8 1 2 .5 3.2.8c1.3.3 2.5.7 3.7 1.2a13 13 0 0 1 3.2 1.8 8.1 8.1 0 0 1 3 6.5c0 2-.5 3.7-1.5 5.1a10 10 0 0 1-4.4 3.5c-1.8.8-4.1 1.2-6.8 1.2-2.6 0-4.9-.4-6.8-1.2-2-.8-3.4-2-4.5-3.5a10 10 0 0 1-1.7-5.6h6a5 5 0 0 0 3.5 4.6c1 .4 2.2.6 3.4.6 1.3 0 2.5-.2 3.5-.6 1-.4 1.8-1 2.4-1.7a4 4 0 0 0 .8-2.4c0-.9-.2-1.6-.7-2.2a11 11 0 0 0-2.1-1.4l-3.2-1-3.8-1c-2.8-.7-5-1.7-6.6-3.2a7.2 7.2 0 0 1-2.4-5.7 8 8 0 0 1 1.7-5 10 10 0 0 1 4.3-3.5c2-.8 4-1.2 6.4-1.2 2.3 0 4.4.4 6.2 1.2 1.8.8 3.2 2 4.3 3.4 1 1.4 1.5 3 1.5 5h-5.8z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
BIN
ui/public/ostris_logo.png
Normal file
|
After Width: | Height: | Size: 22 KiB |
1
ui/public/vercel.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1155 1000"><path d="m577.3 0 577.4 1000H0z" fill="#fff"/></svg>
|
||||
|
After Width: | Height: | Size: 128 B |
BIN
ui/public/web-app-manifest-192x192.png
Normal file
|
After Width: | Height: | Size: 9.7 KiB |
BIN
ui/public/web-app-manifest-512x512.png
Normal file
|
After Width: | Height: | Size: 37 KiB |
1
ui/public/window.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg fill="none" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path fill-rule="evenodd" clip-rule="evenodd" d="M1.5 2.5h13v10a1 1 0 0 1-1 1h-11a1 1 0 0 1-1-1zM0 1h16v11.5a2.5 2.5 0 0 1-2.5 2.5h-11A2.5 2.5 0 0 1 0 12.5zm3.75 4.5a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5M7 4.75a.75.75 0 1 1-1.5 0 .75.75 0 0 1 1.5 0m1.75.75a.75.75 0 1 0 0-1.5.75.75 0 0 0 0 1.5" fill="#666"/></svg>
|
||||
|
After Width: | Height: | Size: 385 B |
42
ui/src/app/api/caption/[...imagePath]/route.ts
Normal file
@@ -0,0 +1,42 @@
/* eslint-disable */
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';

export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
  const { imagePath } = await params;
  try {
    // Decode the path
    const filepath = decodeURIComponent(imagePath);

    // caption name is the filepath without extension but with .txt
    const captionPath = filepath.replace(/\.[^/.]+$/, '') + '.txt';

    // Get allowed directories
    const allowedDir = await getDatasetsRoot();

    // Security check: Ensure path is in allowed directory
    const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');

    if (!isAllowed) {
      console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
      return new NextResponse('Access denied', { status: 403 });
    }

    // Check if file exists
    if (!fs.existsSync(captionPath)) {
      // send back blank string if caption file does not exist
      return new NextResponse('');
    }

    // Read caption file
    const caption = fs.readFileSync(captionPath, 'utf-8');

    // Return caption
    return new NextResponse(caption);
  } catch (error) {
    console.error('Error getting caption:', error);
    return new NextResponse('Error getting caption', { status: 500 });
  }
}
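For orientation, here is a rough sketch (not code from this PR) of how the UI might fetch a caption through the route above. The helper name and example path are placeholders, and the exact URL encoding depends on how the calling component builds the request.

```ts
// Hypothetical helper, for illustration only.
// Fetches the caption text for a dataset image via /api/caption/<path>.
async function loadCaption(imagePath: string): Promise<string> {
  // encode the absolute image path so it survives as a URL segment
  const res = await fetch(`/api/caption/${encodeURIComponent(imagePath)}`);
  if (!res.ok) throw new Error(`Caption request failed: ${res.status}`);
  // the route returns an empty string when no .txt caption exists yet
  return res.text();
}
```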
22
ui/src/app/api/datasets/create/route.tsx
Normal file
@@ -0,0 +1,22 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { name } = body;
    let datasetsPath = await getDatasetsRoot();
    let datasetPath = path.join(datasetsPath, name);

    // if the folder doesn't exist, create it
    if (!fs.existsSync(datasetPath)) {
      fs.mkdirSync(datasetPath);
    }

    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
  }
}
24
ui/src/app/api/datasets/delete/route.tsx
Normal file
@@ -0,0 +1,24 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/server/settings';

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { name } = body;
    let datasetsPath = await getDatasetsRoot();
    let datasetPath = path.join(datasetsPath, name);

    // if the folder doesn't exist, there is nothing to delete
    if (!fs.existsSync(datasetPath)) {
      return NextResponse.json({ success: true });
    }

    // delete it and return success
    fs.rmdirSync(datasetPath, { recursive: true });
    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to delete dataset' }, { status: 500 });
  }
}
25
ui/src/app/api/datasets/list/route.ts
Normal file
@@ -0,0 +1,25 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';

export async function GET() {
  try {
    let datasetsPath = await getDatasetsRoot();

    // if the folder doesn't exist, create it
    if (!fs.existsSync(datasetsPath)) {
      fs.mkdirSync(datasetsPath);
    }

    // find all the folders in the datasets folder
    let folders = fs
      .readdirSync(datasetsPath, { withFileTypes: true })
      .filter(dirent => dirent.isDirectory())
      .filter(dirent => !dirent.name.startsWith('.'))
      .map(dirent => dirent.name);

    return NextResponse.json(folders);
  } catch (error) {
    return NextResponse.json({ error: 'Failed to fetch datasets' }, { status: 500 });
  }
}
67
ui/src/app/api/datasets/listImages/route.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const datasetsPath = await getDatasetsRoot();
|
||||
const body = await request.json();
|
||||
const { datasetName } = body;
|
||||
const datasetFolder = path.join(datasetsPath, datasetName);
|
||||
|
||||
try {
|
||||
// Check if folder exists
|
||||
if (!fs.existsSync(datasetFolder)) {
|
||||
return NextResponse.json(
|
||||
{ error: `Folder '${datasetName}' not found` },
|
||||
{ status: 404 }
|
||||
);
|
||||
}
|
||||
|
||||
// Find all images recursively
|
||||
const imageFiles = findImagesRecursively(datasetFolder);
|
||||
|
||||
// Format response
|
||||
const result = imageFiles.map(imgPath => ({
|
||||
img_path: imgPath
|
||||
}));
|
||||
|
||||
return NextResponse.json({ images: result });
|
||||
} catch (error) {
|
||||
console.error('Error finding images:', error);
|
||||
return NextResponse.json(
|
||||
{ error: 'Failed to process request' },
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively finds all image files in a directory and its subdirectories
|
||||
* @param dir Directory to search
|
||||
* @returns Array of absolute paths to image files
|
||||
*/
|
||||
function findImagesRecursively(dir: string): string[] {
|
||||
const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp'];
|
||||
let results: string[] = [];
|
||||
|
||||
const items = fs.readdirSync(dir);
|
||||
|
||||
for (const item of items) {
|
||||
const itemPath = path.join(dir, item);
|
||||
const stat = fs.statSync(itemPath);
|
||||
|
||||
if (stat.isDirectory()) {
|
||||
// If it's a directory, recursively search it
|
||||
results = results.concat(findImagesRecursively(itemPath));
|
||||
} else {
|
||||
// If it's a file, check if it's an image
|
||||
const ext = path.extname(itemPath).toLowerCase();
|
||||
if (imageExtensions.includes(ext)) {
|
||||
results.push(itemPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
55
ui/src/app/api/datasets/upload/route.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
// src/app/api/datasets/upload/route.ts
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { writeFile, mkdir } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const datasetsPath = await getDatasetsRoot();
|
||||
if (!datasetsPath) {
|
||||
return NextResponse.json({ error: 'Datasets path not found' }, { status: 500 });
|
||||
}
|
||||
const formData = await request.formData();
|
||||
const files = formData.getAll('files');
|
||||
const datasetName = formData.get('datasetName') as string;
|
||||
|
||||
if (!files || files.length === 0) {
|
||||
return NextResponse.json({ error: 'No files provided' }, { status: 400 });
|
||||
}
|
||||
|
||||
// Create upload directory if it doesn't exist
|
||||
const uploadDir = join(datasetsPath, datasetName);
|
||||
await mkdir(uploadDir, { recursive: true });
|
||||
|
||||
const savedFiles = await Promise.all(
|
||||
files.map(async (file: any) => {
|
||||
const bytes = await file.arrayBuffer();
|
||||
const buffer = Buffer.from(bytes);
|
||||
|
||||
// Clean filename and ensure it's unique
|
||||
const fileName = file.name.replace(/[^a-zA-Z0-9.-]/g, '_');
|
||||
const filePath = join(uploadDir, fileName);
|
||||
|
||||
await writeFile(filePath, buffer);
|
||||
return fileName;
|
||||
}),
|
||||
);
|
||||
|
||||
return NextResponse.json({
|
||||
message: 'Files uploaded successfully',
|
||||
files: savedFiles,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Upload error:', error);
|
||||
return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
||||
// Increase payload size limit (default is 4mb)
|
||||
export const config = {
|
||||
api: {
|
||||
bodyParser: false,
|
||||
responseLimit: '50mb',
|
||||
},
|
||||
};
|
||||
106
ui/src/app/api/files/[...filePath]/route.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
/* eslint-disable */
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
|
||||
const { filePath } = await params;
|
||||
try {
|
||||
// Decode the path
|
||||
const decodedFilePath = decodeURIComponent(filePath);
|
||||
|
||||
// Get allowed directories
|
||||
const datasetRoot = await getDatasetsRoot();
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
const allowedDirs = [datasetRoot, trainingRoot];
|
||||
|
||||
// Security check: Ensure path is in allowed directory
|
||||
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
||||
|
||||
if (!isAllowed) {
|
||||
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
|
||||
return new NextResponse('Access denied', { status: 403 });
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!fs.existsSync(decodedFilePath)) {
|
||||
console.warn(`File not found: ${decodedFilePath}`);
|
||||
return new NextResponse('File not found', { status: 404 });
|
||||
}
|
||||
|
||||
// Get file info
|
||||
const stat = fs.statSync(decodedFilePath);
|
||||
if (!stat.isFile()) {
|
||||
return new NextResponse('Not a file', { status: 400 });
|
||||
}
|
||||
|
||||
// Get filename for Content-Disposition
|
||||
const filename = path.basename(decodedFilePath);
|
||||
|
||||
// Determine content type
|
||||
const ext = path.extname(decodedFilePath).toLowerCase();
|
||||
const contentTypeMap: { [key: string]: string } = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml',
|
||||
'.bmp': 'image/bmp',
|
||||
'.safetensors': 'application/octet-stream',
|
||||
};
|
||||
|
||||
const contentType = contentTypeMap[ext] || 'application/octet-stream';
|
||||
|
||||
// Get range header for partial content support
|
||||
const range = request.headers.get('range');
|
||||
|
||||
// Common headers for better download handling
|
||||
const commonHeaders = {
|
||||
'Content-Type': contentType,
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Cache-Control': 'public, max-age=86400',
|
||||
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
|
||||
'X-Content-Type-Options': 'nosniff'
|
||||
};
|
||||
|
||||
if (range) {
|
||||
// Parse range header
|
||||
const parts = range.replace(/bytes=/, '').split('-');
|
||||
const start = parseInt(parts[0], 10);
|
||||
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
|
||||
const chunkSize = (end - start) + 1;
|
||||
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
start,
|
||||
end,
|
||||
highWaterMark: 64 * 1024 // 64KB buffer
|
||||
});
|
||||
|
||||
return new NextResponse(fileStream as any, {
|
||||
status: 206,
|
||||
headers: {
|
||||
...commonHeaders,
|
||||
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
|
||||
'Content-Length': String(chunkSize)
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// For full file download, read directly without streaming wrapper
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
highWaterMark: 64 * 1024 // 64KB buffer
|
||||
});
|
||||
|
||||
return new NextResponse(fileStream as any, {
|
||||
headers: {
|
||||
...commonHeaders,
|
||||
'Content-Length': String(stat.size)
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error serving file:', error);
|
||||
return new NextResponse('Internal Server Error', { status: 500 });
|
||||
}
|
||||
}
|
||||
106
ui/src/app/api/gpu/route.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { exec } from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
|
||||
const execAsync = promisify(exec);
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
// Check if nvidia-smi is available
|
||||
const hasNvidiaSmi = await checkNvidiaSmi();
|
||||
|
||||
if (!hasNvidiaSmi) {
|
||||
return NextResponse.json({
|
||||
hasNvidiaSmi: false,
|
||||
gpus: [],
|
||||
error: 'nvidia-smi not found or not accessible',
|
||||
});
|
||||
}
|
||||
|
||||
// Get GPU stats
|
||||
const gpuStats = await getGpuStats();
|
||||
|
||||
return NextResponse.json({
|
||||
hasNvidiaSmi: true,
|
||||
gpus: gpuStats,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error fetching NVIDIA GPU stats:', error);
|
||||
return NextResponse.json(
|
||||
{
|
||||
hasNvidiaSmi: false,
|
||||
gpus: [],
|
||||
error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
|
||||
},
|
||||
{ status: 500 },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function checkNvidiaSmi(): Promise<boolean> {
|
||||
try {
|
||||
await execAsync('which nvidia-smi');
|
||||
return true;
|
||||
} catch (error) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async function getGpuStats() {
|
||||
// Get detailed GPU information in JSON format including fan speed
|
||||
const { stdout } = await execAsync(
|
||||
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits',
|
||||
);
|
||||
|
||||
// Parse CSV output
|
||||
const gpus = stdout
|
||||
.trim()
|
||||
.split('\n')
|
||||
.map(line => {
|
||||
const [
|
||||
index,
|
||||
name,
|
||||
driverVersion,
|
||||
temperature,
|
||||
gpuUtil,
|
||||
memoryUtil,
|
||||
memoryTotal,
|
||||
memoryFree,
|
||||
memoryUsed,
|
||||
powerDraw,
|
||||
powerLimit,
|
||||
clockGraphics,
|
||||
clockMemory,
|
||||
fanSpeed,
|
||||
] = line.split(', ').map(item => item.trim());
|
||||
|
||||
return {
|
||||
index: parseInt(index),
|
||||
name,
|
||||
driverVersion,
|
||||
temperature: parseInt(temperature),
|
||||
utilization: {
|
||||
gpu: parseInt(gpuUtil),
|
||||
memory: parseInt(memoryUtil),
|
||||
},
|
||||
memory: {
|
||||
total: parseInt(memoryTotal),
|
||||
free: parseInt(memoryFree),
|
||||
used: parseInt(memoryUsed),
|
||||
},
|
||||
power: {
|
||||
draw: parseFloat(powerDraw),
|
||||
limit: parseFloat(powerLimit),
|
||||
},
|
||||
clocks: {
|
||||
graphics: parseInt(clockGraphics),
|
||||
memory: parseInt(clockMemory),
|
||||
},
|
||||
fan: {
|
||||
speed: parseInt(fanSpeed), // Fan speed as percentage
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
return gpus;
|
||||
}
|
||||
68
ui/src/app/api/img/[...imagePath]/route.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
/* eslint-disable */
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
|
||||
const { imagePath } = await params;
|
||||
try {
|
||||
// Decode the path
|
||||
const filepath = decodeURIComponent(imagePath);
|
||||
|
||||
// Get allowed directories
|
||||
const datasetRoot = await getDatasetsRoot();
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
|
||||
const allowedDirs = [datasetRoot, trainingRoot];
|
||||
|
||||
// Security check: Ensure path is in allowed directory
|
||||
const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
|
||||
|
||||
if (!isAllowed) {
|
||||
console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
|
||||
return new NextResponse('Access denied', { status: 403 });
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!fs.existsSync(filepath)) {
|
||||
console.warn(`File not found: ${filepath}`);
|
||||
return new NextResponse('File not found', { status: 404 });
|
||||
}
|
||||
|
||||
// Get file info
|
||||
const stat = fs.statSync(filepath);
|
||||
if (!stat.isFile()) {
|
||||
return new NextResponse('Not a file', { status: 400 });
|
||||
}
|
||||
|
||||
// Determine content type
|
||||
const ext = path.extname(filepath).toLowerCase();
|
||||
const contentTypeMap: { [key: string]: string } = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml',
|
||||
'.bmp': 'image/bmp',
|
||||
};
|
||||
|
||||
const contentType = contentTypeMap[ext] || 'application/octet-stream';
|
||||
|
||||
// Read file as buffer
|
||||
const fileBuffer = fs.readFileSync(filepath);
|
||||
|
||||
// Return file with appropriate headers
|
||||
return new NextResponse(fileBuffer, {
|
||||
headers: {
|
||||
'Content-Type': contentType,
|
||||
'Content-Length': String(stat.size),
|
||||
'Cache-Control': 'public, max-age=86400',
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error serving image:', error);
|
||||
return new NextResponse('Internal Server Error', { status: 500 });
|
||||
}
|
||||
}
|
||||
30
ui/src/app/api/img/caption/route.ts
Normal file
@@ -0,0 +1,30 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { imgPath, caption } = body;
    let datasetsPath = await getDatasetsRoot();
    // make sure the image path is inside the datasets folder
    if (!imgPath.startsWith(datasetsPath)) {
      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
    }

    // if the image doesn't exist, there is nothing to caption
    if (!fs.existsSync(imgPath)) {
      return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
    }

    // the caption file is the image path with a .txt extension
    const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
    // save caption to file
    fs.writeFileSync(captionPath, caption);

    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to save caption' }, { status: 500 });
  }
}
34
ui/src/app/api/img/delete/route.ts
Normal file
@@ -0,0 +1,34 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/server/settings';

export async function POST(request: Request) {
  try {
    const body = await request.json();
    const { imgPath } = body;
    let datasetsPath = await getDatasetsRoot();
    // make sure the image path is inside the datasets folder
    if (!imgPath.startsWith(datasetsPath)) {
      return NextResponse.json({ error: 'Invalid image path' }, { status: 400 });
    }

    // if the image doesn't exist, there is nothing to delete
    if (!fs.existsSync(imgPath)) {
      return NextResponse.json({ success: true });
    }

    // delete it and return success
    fs.unlinkSync(imgPath);

    // also remove the matching caption file if there is one
    const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
    if (fs.existsSync(captionPath)) {
      // delete caption file
      fs.unlinkSync(captionPath);
    }

    return NextResponse.json({ success: true });
  } catch (error) {
    return NextResponse.json({ error: 'Failed to delete image' }, { status: 500 });
  }
}
32
ui/src/app/api/jobs/[jobID]/delete/route.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
if (!job) {
|
||||
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
const trainingFolder = path.join(trainingRoot, job.name);
|
||||
|
||||
if (fs.existsSync(trainingFolder)) {
|
||||
fs.rmdirSync(trainingFolder, { recursive: true });
|
||||
}
|
||||
|
||||
await prisma.job.delete({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
return NextResponse.json(job);
|
||||
}
|
||||
48
ui/src/app/api/jobs/[jobID]/files/route.ts
Normal file
@@ -0,0 +1,48 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  const trainingFolder = await getTrainingFolder();
  const jobFolder = path.join(trainingFolder, job.name);

  if (!fs.existsSync(jobFolder)) {
    return NextResponse.json({ files: [] });
  }

  // find all .safetensors checkpoints in the job folder
  let files = fs
    .readdirSync(jobFolder)
    .filter(file => {
      return file.endsWith('.safetensors');
    })
    .map(file => {
      return path.join(jobFolder, file);
    })
    .sort();

  // get the file size for each file
  const fileObjects = files.map(file => {
    const stats = fs.statSync(file);
    return {
      path: file,
      size: stats.size,
    };
  });

  return NextResponse.json({ files: fileObjects });
}
40
ui/src/app/api/jobs/[jobID]/samples/route.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
if (!job) {
|
||||
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
// setup the training
|
||||
const trainingFolder = await getTrainingFolder();
|
||||
|
||||
const samplesFolder = path.join(trainingFolder, job.name, 'samples');
|
||||
if (!fs.existsSync(samplesFolder)) {
|
||||
return NextResponse.json({ samples: [] });
|
||||
}
|
||||
|
||||
// find all img (png, jpg, jpeg) files in the samples folder
|
||||
const samples = fs
|
||||
.readdirSync(samplesFolder)
|
||||
.filter(file => {
|
||||
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg');
|
||||
})
|
||||
.map(file => {
|
||||
return path.join(samplesFolder, file);
|
||||
})
|
||||
.sort();
|
||||
|
||||
return NextResponse.json({ samples });
|
||||
}
|
||||
93
ui/src/app/api/jobs/[jobID]/start/route.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
||||
import { spawn } from 'child_process';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
if (!job) {
|
||||
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
// update job status to 'running'
|
||||
await prisma.job.update({
|
||||
where: { id: jobID },
|
||||
data: {
|
||||
status: 'running',
|
||||
stop: false,
|
||||
info: 'Starting job...',
|
||||
},
|
||||
});
|
||||
|
||||
// setup the training
|
||||
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
|
||||
const trainingFolder = path.join(trainingRoot, job.name);
|
||||
if (!fs.existsSync(trainingFolder)) {
|
||||
fs.mkdirSync(trainingFolder, { recursive: true });
|
||||
}
|
||||
|
||||
// make the config file
|
||||
const configPath = path.join(trainingFolder, '.job_config.json');
|
||||
|
||||
// update the config dataset path
|
||||
const jobConfig = JSON.parse(job.job_config);
|
||||
jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
|
||||
|
||||
// write the config file
|
||||
fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
|
||||
|
||||
let pythonPath = 'python';
|
||||
// use .venv or venv if it exists
|
||||
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
||||
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
||||
}
|
||||
|
||||
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
|
||||
if (!fs.existsSync(runFilePath)) {
|
||||
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
|
||||
}
|
||||
const additionalEnv: any = {
|
||||
AITK_JOB_ID: jobID,
|
||||
CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
|
||||
};
|
||||
|
||||
// HF_TOKEN
|
||||
const hfToken = await getHFToken();
|
||||
if (hfToken && hfToken.trim() !== '') {
|
||||
additionalEnv.HF_TOKEN = hfToken;
|
||||
}
|
||||
|
||||
// console.log(
|
||||
// 'Spawning command:',
|
||||
// `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
|
||||
// );
|
||||
|
||||
// start job
|
||||
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||
detached: true,
|
||||
stdio: 'ignore',
|
||||
env: {
|
||||
...process.env,
|
||||
...additionalEnv,
|
||||
},
|
||||
cwd: TOOLKIT_ROOT,
|
||||
});
|
||||
|
||||
subprocess.unref();
|
||||
|
||||
return NextResponse.json(job);
|
||||
}
|
||||
23
ui/src/app/api/jobs/[jobID]/stop/route.ts
Normal file
@@ -0,0 +1,23 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';

const prisma = new PrismaClient();

export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  // flag the job to stop
  await prisma.job.update({
    where: { id: jobID },
    data: {
      stop: true,
      info: 'Stopping job...',
    },
  });

  return NextResponse.json(job);
}
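Note that this endpoint only flips the `stop` flag in the database and returns the job row as it was before the update; the running training process is expected to notice the flag and wind down on its own. A minimal client-side sketch follows (the repo's own `stopJob` helper in `@/utils/jobs`, used by the job page later in this diff, presumably does something similar):

```ts
// Illustrative helper, not code from this PR.
async function requestStop(jobID: string) {
  const res = await fetch(`/api/jobs/${jobID}/stop`);
  if (!res.ok) throw new Error(`Failed to stop job ${jobID}: ${res.status}`);
  // the response is the job record from before `stop` was set to true
  return res.json();
}
```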
58
ui/src/app/api/jobs/route.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams } = new URL(request.url);
|
||||
const id = searchParams.get('id');
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
return NextResponse.json(job);
|
||||
}
|
||||
|
||||
const jobs = await prisma.job.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
return NextResponse.json({ jobs: jobs });
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { id, name, job_config, gpu_ids } = body;
|
||||
|
||||
if (id) {
|
||||
// Update existing training
|
||||
const training = await prisma.job.update({
|
||||
where: { id },
|
||||
data: {
|
||||
name,
|
||||
gpu_ids,
|
||||
job_config: JSON.stringify(job_config),
|
||||
},
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
} else {
|
||||
// Create new training
|
||||
const training = await prisma.job.create({
|
||||
data: {
|
||||
name,
|
||||
gpu_ids,
|
||||
job_config: JSON.stringify(job_config),
|
||||
},
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
}
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
59
ui/src/app/api/settings/route.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
|
||||
import {flushCache} from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
const settings = await prisma.settings.findMany();
|
||||
const settingsObject = settings.reduce((acc: any, setting) => {
|
||||
acc[setting.key] = setting.value;
|
||||
return acc;
|
||||
}, {});
|
||||
// if TRAINING_FOLDER is not set, use default
|
||||
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
|
||||
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
|
||||
}
|
||||
// if DATASETS_FOLDER is not set, use default
|
||||
if (!settingsObject.DATASETS_FOLDER || settingsObject.DATASETS_FOLDER === '') {
|
||||
settingsObject.DATASETS_FOLDER = defaultDatasetsFolder;
|
||||
}
|
||||
return NextResponse.json(settingsObject);
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { HF_TOKEN, TRAINING_FOLDER, DATASETS_FOLDER } = body;
|
||||
|
||||
// Upsert both settings
|
||||
await Promise.all([
|
||||
prisma.settings.upsert({
|
||||
where: { key: 'HF_TOKEN' },
|
||||
update: { value: HF_TOKEN },
|
||||
create: { key: 'HF_TOKEN', value: HF_TOKEN },
|
||||
}),
|
||||
prisma.settings.upsert({
|
||||
where: { key: 'TRAINING_FOLDER' },
|
||||
update: { value: TRAINING_FOLDER },
|
||||
create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER },
|
||||
}),
|
||||
prisma.settings.upsert({
|
||||
where: { key: 'DATASETS_FOLDER' },
|
||||
update: { value: DATASETS_FOLDER },
|
||||
create: { key: 'DATASETS_FOLDER', value: DATASETS_FOLDER },
|
||||
}),
|
||||
]);
|
||||
|
||||
flushCache();
|
||||
|
||||
return NextResponse.json({ success: true });
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
BIN
ui/src/app/apple-icon.png
Normal file (8.7 KiB)
31
ui/src/app/dashboard/page.tsx
Normal file
@@ -0,0 +1,31 @@
'use client';

import GpuMonitor from '@/components/GPUMonitor';
import JobsTable from '@/components/JobsTable';
import { TopBar, MainContent } from '@/components/layout';
import Link from 'next/link';

export default function Dashboard() {
  return (
    <>
      <TopBar>
        <div>
          <h1 className="text-lg">Dashboard</h1>
        </div>
        <div className="flex-1"></div>
      </TopBar>
      <MainContent>
        <GpuMonitor />
        <div className="w-full mt-4">
          <div className="flex justify-between items-center mb-2">
            <h1 className="text-md">Active Jobs</h1>
            <div className="text-xs text-gray-500">
              <Link href="/jobs">View All</Link>
            </div>
          </div>
          <JobsTable onlyActive />
        </div>
      </MainContent>
    </>
  );
}
86
ui/src/app/datasets/[datasetName]/page.tsx
Normal file
@@ -0,0 +1,86 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState, use } from 'react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
import DatasetImageCard from '@/components/DatasetImageCard';
|
||||
import { Button } from '@headlessui/react';
|
||||
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
|
||||
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
|
||||
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
|
||||
const usableParams = use(params as any) as { datasetName: string };
|
||||
const datasetName = usableParams.datasetName;
|
||||
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
|
||||
|
||||
const refreshImageList = (dbName: string) => {
|
||||
setStatus('loading');
|
||||
fetch('/api/datasets/listImages', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ datasetName: dbName }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Images:', data.images);
|
||||
// sort
|
||||
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
|
||||
setImgList(data.images);
|
||||
setStatus('success');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching images:', error);
|
||||
setStatus('error');
|
||||
});
|
||||
};
|
||||
useEffect(() => {
|
||||
if (datasetName) {
|
||||
refreshImageList(datasetName);
|
||||
}
|
||||
}, [datasetName]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Fixed top bar */}
|
||||
<TopBar>
|
||||
<div>
|
||||
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
||||
<FaChevronLeft />
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<h1 className="text-lg">Dataset: {datasetName}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
<div>
|
||||
<Button
|
||||
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
|
||||
onClick={() => openImagesModal(datasetName, () => refreshImageList(datasetName))}
|
||||
>
|
||||
Add Images
|
||||
</Button>
|
||||
</div>
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
{status === 'loading' && <p>Loading...</p>}
|
||||
{status === 'error' && <p>Error fetching images</p>}
|
||||
{status === 'success' && (
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-4">
|
||||
{imgList.length === 0 && <p>No images found</p>}
|
||||
{imgList.map(img => (
|
||||
<DatasetImageCard
|
||||
key={img.img_path}
|
||||
alt="image"
|
||||
imageUrl={img.img_path}
|
||||
onDelete={() => refreshImageList(datasetName)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</MainContent>
|
||||
<AddImagesModal />
|
||||
</>
|
||||
);
|
||||
}
|
||||
157
ui/src/app/datasets/page.tsx
Normal file
@@ -0,0 +1,157 @@
|
||||
'use client';
|
||||
|
||||
import { useState } from 'react';
|
||||
import { Modal } from '@/components/Modal';
|
||||
import Link from 'next/link';
|
||||
import { TextInput } from '@/components/formInputs';
|
||||
import useDatasetList from '@/hooks/useDatasetList';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { FaRegTrashAlt } from 'react-icons/fa';
|
||||
import { openConfirm } from '@/components/ConfirmModal';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
|
||||
|
||||
export default function Datasets() {
|
||||
const { datasets, status, refreshDatasets } = useDatasetList();
|
||||
const [newDatasetName, setNewDatasetName] = useState('');
|
||||
const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
|
||||
|
||||
// Transform datasets array into rows with objects
|
||||
const tableRows = datasets.map(dataset => ({
|
||||
name: dataset,
|
||||
actions: dataset, // Pass full dataset name for actions
|
||||
}));
|
||||
|
||||
const columns: TableColumn[] = [
|
||||
{
|
||||
title: 'Dataset Name',
|
||||
key: 'name',
|
||||
render: row => (
|
||||
<Link href={`/datasets/${row.name}`} className="text-gray-200 hover:text-gray-100">
|
||||
{row.name}
|
||||
</Link>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: 'Actions',
|
||||
key: 'actions',
|
||||
className: 'w-20 text-right',
|
||||
render: row => (
|
||||
<button
|
||||
className="text-gray-200 hover:bg-red-600 p-2 rounded-full transition-colors"
|
||||
onClick={() => handleDeleteDataset(row.name)}
|
||||
>
|
||||
<FaRegTrashAlt />
|
||||
</button>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
const handleDeleteDataset = (datasetName: string) => {
|
||||
openConfirm({
|
||||
title: 'Delete Dataset',
|
||||
message: `Are you sure you want to delete the dataset "${datasetName}"? This action cannot be undone.`,
|
||||
type: 'warning',
|
||||
confirmText: 'Delete',
|
||||
onConfirm: () => {
|
||||
fetch('/api/datasets/delete', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ name: datasetName }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Dataset deleted:', data);
|
||||
refreshDatasets();
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error deleting dataset:', error);
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const handleCreateDataset = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
try {
|
||||
const response = await fetch('/api/datasets/create', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ name: newDatasetName }),
|
||||
});
|
||||
const data = await response.json();
|
||||
console.log('New dataset created:', data);
|
||||
refreshDatasets();
|
||||
setNewDatasetName('');
|
||||
setIsNewDatasetModalOpen(false);
|
||||
} catch (error) {
|
||||
console.error('Error creating new dataset:', error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<TopBar>
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-gray-100">Datasets</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
<div>
|
||||
<Button
|
||||
className="text-gray-200 bg-slate-600 px-4 py-2 rounded-md hover:bg-slate-500 transition-colors"
|
||||
onClick={() => setIsNewDatasetModalOpen(true)}
|
||||
>
|
||||
New Dataset
|
||||
</Button>
|
||||
</div>
|
||||
</TopBar>
|
||||
|
||||
<MainContent>
|
||||
<UniversalTable
|
||||
columns={columns}
|
||||
rows={tableRows}
|
||||
isLoading={status === 'loading'}
|
||||
onRefresh={refreshDatasets}
|
||||
/>
|
||||
</MainContent>
|
||||
|
||||
<Modal
|
||||
isOpen={isNewDatasetModalOpen}
|
||||
onClose={() => setIsNewDatasetModalOpen(false)}
|
||||
title="New Dataset"
|
||||
size="md"
|
||||
>
|
||||
<div className="space-y-4 text-gray-200">
|
||||
<form onSubmit={handleCreateDataset}>
|
||||
<div className="text-sm text-gray-400">
|
||||
This will create a new folder with the name below in your dataset folder.
|
||||
</div>
|
||||
<div className="mt-4">
|
||||
<TextInput label="Dataset Name" value={newDatasetName} onChange={value => setNewDatasetName(value)} />
|
||||
</div>
|
||||
|
||||
<div className="mt-6 flex justify-end space-x-3">
|
||||
<button
|
||||
type="button"
|
||||
className="rounded-md bg-gray-700 px-4 py-2 text-gray-200 hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-gray-500"
|
||||
onClick={() => setIsNewDatasetModalOpen(false)}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
type="submit"
|
||||
className="rounded-md bg-blue-600 px-4 py-2 text-white hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
>
|
||||
Confirm
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
||||
BIN
ui/src/app/favicon.ico
Normal file (15 KiB)
21
ui/src/app/globals.css
Normal file
@@ -0,0 +1,21 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

:root {
  --background: #ffffff;
  --foreground: #171717;
}

@media (prefers-color-scheme: dark) {
  :root {
    --background: #0a0a0a;
    --foreground: #ededed;
  }
}

body {
  color: var(--foreground);
  background: var(--background);
  font-family: Arial, Helvetica, sans-serif;
}
BIN
ui/src/app/icon.png
Normal file (4.3 KiB)
3
ui/src/app/icon.svg
Normal file (110 KiB)
80
ui/src/app/jobs/[jobID]/page.tsx
Normal file
@@ -0,0 +1,80 @@
|
||||
'use client';
|
||||
|
||||
import { useMemo, useState, use } from 'react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import useJob from '@/hooks/useJob';
|
||||
import { startJob, stopJob } from '@/utils/jobs';
|
||||
import SampleImages from '@/components/SampleImages';
|
||||
import JobOverview from '@/components/JobOverview';
|
||||
import { JobConfig } from '@/types';
|
||||
import { redirect } from 'next/navigation';
|
||||
import JobActionBar from '@/components/JobActionBar';
|
||||
|
||||
type PageKey = 'overview' | 'samples';
|
||||
|
||||
interface Page {
|
||||
name: string;
|
||||
value: PageKey;
|
||||
}
|
||||
|
||||
const pages: Page[] = [
|
||||
{ name: 'Overview', value: 'overview' },
|
||||
{ name: 'Samples', value: 'samples' },
|
||||
];
|
||||
|
||||
export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
const usableParams = use(params as any) as { jobID: string };
|
||||
const jobID = usableParams.jobID;
|
||||
const { job, status, refreshJob } = useJob(jobID, 5000);
|
||||
const [pageKey, setPageKey] = useState<PageKey>('overview');
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Fixed top bar */}
|
||||
<TopBar>
|
||||
<div>
|
||||
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => redirect('/jobs')}>
|
||||
<FaChevronLeft />
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<h1 className="text-lg">Job: {job?.name}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
{job && (
|
||||
<JobActionBar
|
||||
job={job}
|
||||
onRefresh={refreshJob}
|
||||
hideView
|
||||
afterDelete={() => {
|
||||
redirect('/jobs');
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</TopBar>
|
||||
<MainContent className="pt-24">
|
||||
{status === 'loading' && job == null && <p>Loading...</p>}
|
||||
{status === 'error' && job == null && <p>Error fetching job</p>}
|
||||
{job && (
|
||||
<>
|
||||
{pageKey === 'overview' && <JobOverview job={job} />}
|
||||
{pageKey === 'samples' && <SampleImages job={job} />}
|
||||
</>
|
||||
)}
|
||||
</MainContent>
|
||||
<div className="bg-gray-800 absolute top-12 left-0 w-full h-8 flex items-center px-2 text-sm">
|
||||
{pages.map(page => (
|
||||
<Button
|
||||
key={page.value}
|
||||
onClick={() => setPageKey(page.value)}
|
||||
className={`px-4 py-1 h-8 ${page.value === pageKey ? 'bg-gray-300 dark:bg-gray-700' : ''}`}
|
||||
>
|
||||
{page.name}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
101
ui/src/app/jobs/new/jobConfig.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import { JobConfig, DatasetConfig } from '@/types';
|
||||
|
||||
export const defaultDatasetConfig: DatasetConfig = {
|
||||
folder_path: '/path/to/images/folder',
|
||||
mask_path: null,
|
||||
mask_min_value: 0.1,
|
||||
default_caption: '',
|
||||
caption_ext: 'txt',
|
||||
caption_dropout_rate: 0.05,
|
||||
cache_latents_to_disk: false,
|
||||
is_reg: false,
|
||||
network_weight: 1,
|
||||
resolution: [512, 768, 1024],
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
job: 'extension',
|
||||
config: {
|
||||
name: 'my_first_flex_lora_v1',
|
||||
process: [
|
||||
{
|
||||
type: 'ui_trainer',
|
||||
training_folder: 'output',
|
||||
sqlite_db_path: './aitk_db.db',
|
||||
device: 'cuda:0',
|
||||
trigger_word: null,
|
||||
performance_log_every: 10,
|
||||
network: {
|
||||
type: 'lora',
|
||||
linear: 16,
|
||||
linear_alpha: 16,
|
||||
},
|
||||
save: {
|
||||
dtype: 'bf16',
|
||||
save_every: 250,
|
||||
max_step_saves_to_keep: 4,
|
||||
save_format: 'diffusers',
|
||||
push_to_hub: false,
|
||||
},
|
||||
datasets: [
|
||||
defaultDatasetConfig
|
||||
],
|
||||
train: {
|
||||
batch_size: 1,
|
||||
bypass_guidance_embedding: true,
|
||||
steps: 2000,
|
||||
gradient_accumulation: 1,
|
||||
train_unet: true,
|
||||
train_text_encoder: false,
|
||||
gradient_checkpointing: true,
|
||||
noise_scheduler: 'flowmatch',
|
||||
optimizer: 'adamw8bit',
|
||||
timestep_type: 'sigmoid',
|
||||
content_or_style: 'balanced',
|
||||
optimizer_params: {
|
||||
weight_decay: 1e-4
|
||||
},
|
||||
lr: 0.0001,
|
||||
ema_config: {
|
||||
use_ema: true,
|
||||
ema_decay: 0.99,
|
||||
},
|
||||
dtype: 'bf16',
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
is_flux: true,
|
||||
quantize: true,
|
||||
quantize_te: true
|
||||
},
|
||||
sample: {
|
||||
sampler: 'flowmatch',
|
||||
sample_every: 250,
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
prompts: [
|
||||
'woman with red hair, playing chess at the park, bomb going off in the background',
|
||||
'a woman holding a coffee cup, in a beanie, sitting at a cafe',
|
||||
'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
|
||||
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
|
||||
'a bear building a log cabin in the snow covered mountains',
|
||||
'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
|
||||
'hipster man with a beard, building a chair, in a wood shop',
|
||||
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
|
||||
"a man holding a sign that says, 'this is a sign'",
|
||||
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
|
||||
],
|
||||
neg: '',
|
||||
seed: 42,
|
||||
walk_seed: true,
|
||||
guidance_scale: 4,
|
||||
sample_steps: 25,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
meta: {
|
||||
name: '[name]',
|
||||
version: '1.0',
|
||||
},
|
||||
};
|
||||
41
ui/src/app/jobs/new/options.ts
Normal file
@@ -0,0 +1,41 @@
export interface Model {
  name_or_path: string;
  defaults?: { [key: string]: any };
}

export interface Option {
  model: Model[];
}

export const options = {
  model: [
    {
      name_or_path: 'ostris/Flex.1-alpha',
      defaults: {
        // default updates when [selected, unselected] in the UI
        'config.process[0].model.quantize': [true, false],
        'config.process[0].model.quantize_te': [true, false],
        'config.process[0].model.is_flux': [true, false],
        'config.process[0].train.bypass_guidance_embedding': [true, false],
      },
    },
    {
      name_or_path: 'black-forest-labs/FLUX.1-dev',
      defaults: {
        // default updates when [selected, unselected] in the UI
        'config.process[0].model.quantize': [true, false],
        'config.process[0].model.quantize_te': [true, false],
        'config.process[0].model.is_flux': [true, false],
      },
    },
    {
      name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
      defaults: {
        // default updates when [selected, unselected] in the UI
        'config.process[0].model.quantize': [false, false],
        'config.process[0].model.quantize_te': [true, false],
        'config.process[0].model.is_lumina2': [true, false],
      },
    },
  ],
} as Option;
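The `[selected, unselected]` pairs above are consumed by the new-job form later in this diff: when a model is chosen, index 0 of each entry is written into the job config at the dotted key, and when the model is swapped out again, index 1 reverts it. A rough sketch of that pattern (the `setJobConfig(value, keyPath)` signature matches the nested-state setter used in `ui/src/app/jobs/new/page.tsx`; the helper itself is illustrative):

```ts
// Illustrative helper, not code from this PR.
function applyModelDefaults(
  model: Model,
  selected: boolean,
  setJobConfig: (value: any, keyPath: string) => void,
) {
  for (const keyPath in model.defaults ?? {}) {
    // [0] is applied when the model is selected, [1] when it is deselected
    const [onSelect, onDeselect] = model.defaults![keyPath];
    setJobConfig(selected ? onSelect : onDeselect, keyPath);
  }
}
```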
618
ui/src/app/jobs/new/page.tsx
Normal file
@@ -0,0 +1,618 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useSearchParams, useRouter } from 'next/navigation';
|
||||
import { options } from './options';
|
||||
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
import { useNestedState } from '@/utils/hooks';
|
||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
||||
import Card from '@/components/Card';
|
||||
import { X } from 'lucide-react';
|
||||
import useSettings from '@/hooks/useSettings';
|
||||
import useGPUInfo from '@/hooks/useGPUInfo';
|
||||
import useDatasetList from '@/hooks/useDatasetList';
|
||||
import path from 'path';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
|
||||
export default function TrainingForm() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const runId = searchParams.get('id');
|
||||
const [gpuIDs, setGpuIDs] = useState<string | null>(null);
|
||||
const { settings, isSettingsLoaded } = useSettings();
|
||||
const { gpuList, isGPUInfoLoaded } = useGPUInfo();
|
||||
const { datasets, status: datasetFetchStatus } = useDatasetList();
|
||||
const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
|
||||
|
||||
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
useEffect(() => {
|
||||
if (!isSettingsLoaded) return;
|
||||
if (datasetFetchStatus !== 'success') return;
|
||||
|
||||
const datasetOptions = datasets.map(name => ({ value: path.join(settings.DATASETS_FOLDER, name), label: name }));
|
||||
setDatasetOptions(datasetOptions);
|
||||
const defaultDatasetPath = defaultDatasetConfig.folder_path;
|
||||
|
||||
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
|
||||
const dataset = jobConfig.config.process[0].datasets[i];
|
||||
if (dataset.folder_path === defaultDatasetPath) {
|
||||
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
|
||||
}
|
||||
}
|
||||
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
|
||||
|
||||
useEffect(() => {
|
||||
if (runId) {
|
||||
fetch(`/api/jobs?id=${runId}`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setGpuIDs(data.gpu_ids);
|
||||
setJobConfig(JSON.parse(data.job_config));
|
||||
// setJobConfig(data.name, 'config.name');
|
||||
})
|
||||
.catch(error => console.error('Error fetching training:', error));
|
||||
}
|
||||
}, [runId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isGPUInfoLoaded) {
|
||||
if (gpuIDs === null && gpuList.length > 0) {
|
||||
setGpuIDs(`${gpuList[0].index}`);
|
||||
}
|
||||
}
|
||||
}, [gpuList, isGPUInfoLoaded]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isSettingsLoaded) {
|
||||
setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder');
|
||||
}
|
||||
}, [settings, isSettingsLoaded]);
|
||||
|
||||
const saveJob = async () => {
|
||||
if (status === 'saving') return;
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/jobs', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
id: runId,
|
||||
name: jobConfig.config.name,
|
||||
gpu_ids: gpuIDs,
|
||||
job_config: jobConfig,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save training');
|
||||
|
||||
setStatus('success');
|
||||
if (!runId) {
|
||||
const data = await response.json();
|
||||
router.push(`/jobs/${data.id}`);
|
||||
}
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
console.error('Error saving training:', error);
|
||||
setStatus('error');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
saveJob();
|
||||
};
|
||||
|
||||
console.log('jobConfig.config.process[0].network.linear', jobConfig?.config?.process[0].network?.linear);
|
||||
|
||||
return (
|
||||
<>
|
||||
<TopBar>
|
||||
<div>
|
||||
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
||||
<FaChevronLeft />
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
<div>
|
||||
<Button
|
||||
className="text-gray-200 bg-green-800 px-3 py-1 rounded-md"
|
||||
onClick={() => saveJob()}
|
||||
disabled={status === 'saving'}
|
||||
>
|
||||
{status === 'saving' ? 'Saving...' : runId ? 'Update Job' : 'Create Job'}
|
||||
</Button>
|
||||
</div>
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<Card title="Job Settings">
|
||||
<TextInput
|
||||
label="Training Name"
|
||||
value={jobConfig.config.name}
|
||||
onChange={value => setJobConfig(value, 'config.name')}
|
||||
placeholder="Enter training name"
|
||||
disabled={runId !== null}
|
||||
required
|
||||
/>
|
||||
<SelectInput
|
||||
label="GPU ID"
|
||||
value={`${gpuIDs}`}
|
||||
onChange={value => setGpuIDs(value)}
|
||||
options={gpuList.map(gpu => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
||||
/>
|
||||
<TextInput
|
||||
label="Trigger Word"
|
||||
value={jobConfig.config.process[0].trigger_word || ''}
|
||||
onChange={(value: string | null) => {
|
||||
if (value?.trim() === '') {
|
||||
value = null;
|
||||
}
|
||||
setJobConfig(value, 'jobConfig.config.process[0].trigger_word');
|
||||
}}
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* Model Configuration Section */}
|
||||
<Card title="Model Configuration">
|
||||
<SelectInput
|
||||
label="Name or Path"
|
||||
value={jobConfig.config.process[0].model.name_or_path}
|
||||
onChange={value => {
|
||||
// see if model changed
|
||||
const currentModel = options.model.find(
|
||||
model => model.name_or_path === jobConfig.config.process[0].model.name_or_path,
|
||||
);
|
||||
if (!currentModel || currentModel.name_or_path === value) {
|
||||
// model has not changed
|
||||
return;
|
||||
}
|
||||
// revert defaults from previous model
|
||||
for (const key in currentModel.defaults) {
|
||||
setJobConfig(currentModel.defaults[key][1], key);
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.name_or_path');
|
||||
// update the defaults when a model is selected
|
||||
const model = options.model.find(model => model.name_or_path === value);
|
||||
if (model?.defaults) {
|
||||
for (const key in model.defaults) {
|
||||
setJobConfig(model.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
}}
|
||||
options={options.model.map(model => ({
|
||||
value: model.name_or_path,
|
||||
label: model.name_or_path,
|
||||
}))}
|
||||
/>
|
||||
<FormGroup label="Quantize">
|
||||
<div className='grid grid-cols-2 gap-2'>
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Text Encoder"
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
{jobConfig.config.process[0].network?.type && (
|
||||
<Card title="LoRA Configuration">
|
||||
<NumberInput
|
||||
label="Linear Rank"
|
||||
value={jobConfig.config.process[0].network.linear}
|
||||
onChange={value => {
|
||||
console.log('onChange', value);
|
||||
setJobConfig(value, 'config.process[0].network.linear');
|
||||
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
||||
}}
|
||||
placeholder="eg. 16"
|
||||
min={0}
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
<Card title="Save Configuration">
|
||||
<SelectInput
|
||||
label="Data Type"
|
||||
value={jobConfig.config.process[0].save.dtype}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
|
||||
options={[
|
||||
{ value: 'bf16', label: 'BF16' },
|
||||
{ value: 'fp16', label: 'FP16' },
|
||||
{ value: 'fp32', label: 'FP32' },
|
||||
]}
|
||||
/>
|
||||
<NumberInput
|
||||
label="Save Every"
|
||||
value={jobConfig.config.process[0].save.save_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
|
||||
placeholder="eg. 250"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Max Step Saves to Keep"
|
||||
value={jobConfig.config.process[0].save.max_step_saves_to_keep}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
|
||||
placeholder="eg. 4"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Training Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Batch Size"
|
||||
value={jobConfig.config.process[0].train.batch_size}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
|
||||
placeholder="eg. 4"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Gradient Accumulation"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.gradient_accumulation}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
|
||||
placeholder="eg. 1"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Steps"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.steps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
|
||||
placeholder="eg. 2000"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Optimizer"
|
||||
value={jobConfig.config.process[0].train.optimizer}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
|
||||
options={[
|
||||
{ value: 'adamw8bit', label: 'AdamW8Bit' },
|
||||
{ value: 'adafactor', label: 'Adafactor' },
|
||||
]}
|
||||
/>
|
||||
<NumberInput
|
||||
label="Learning Rate"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.lr}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
|
||||
placeholder="eg. 0.0001"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Weight Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
|
||||
placeholder="eg. 0.0001"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Timestep Type"
|
||||
value={jobConfig.config.process[0].train.timestep_type}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||
options={[
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
{ value: 'linear', label: 'Linear' },
|
||||
{ value: 'flux_shift', label: 'Flux Shift' },
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Timestep Bias"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.content_or_style}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
|
||||
options={[
|
||||
{ value: 'balanced', label: 'Balanced' },
|
||||
{ value: 'content', label: 'High Noise' },
|
||||
{ value: 'style', label: 'Low Noise' },
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Noise Scheduler"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
||||
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="EMA (Exponential Moving Average)">
|
||||
<Checkbox
|
||||
label="Use EMA"
|
||||
className='pt-1'
|
||||
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="EMA Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.ema_decay')}
|
||||
placeholder="eg. 0.99"
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Datasets">
|
||||
<>
|
||||
{jobConfig.config.process[0].datasets.map((dataset, i) => (
|
||||
<div key={i} className="p-4 rounded-lg bg-gray-800 relative">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
|
||||
>
|
||||
<X />
|
||||
</button>
|
||||
<h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Dataset"
|
||||
value={dataset.folder_path}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
options={datasetOptions}
|
||||
/>
|
||||
<NumberInput
|
||||
label="LoRA Weight"
|
||||
value={dataset.network_weight}
|
||||
className="pt-2"
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
|
||||
placeholder="eg. 1.0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<TextInput
|
||||
label="Default Caption"
|
||||
value={dataset.default_caption}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
|
||||
placeholder="eg. A photo of a cat"
|
||||
/>
|
||||
<NumberInput
|
||||
label="Caption Dropout Rate"
|
||||
className="pt-2"
|
||||
value={dataset.caption_dropout_rate}
|
||||
onChange={value =>
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)
|
||||
}
|
||||
placeholder="eg. 0.05"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Settings" className="">
|
||||
<Checkbox
|
||||
label="Cache Latents"
|
||||
checked={dataset.cache_latents_to_disk || false}
|
||||
onChange={value =>
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
|
||||
}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Is Regularization"
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Resolutions" className="pt-2">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{[
|
||||
[256, 512, 768],
|
||||
[1024, 1280, 1536],
|
||||
].map(resGroup => (
|
||||
<div key={resGroup[0]} className="space-y-2">
|
||||
{resGroup.map(res => (
|
||||
<Checkbox
|
||||
key={res}
|
||||
label={res.toString()}
|
||||
checked={dataset.resolution.includes(res)}
|
||||
onChange={value => {
|
||||
const resolutions = dataset.resolution.includes(res)
|
||||
? dataset.resolution.filter(r => r !== res)
|
||||
: [...dataset.resolution, res];
|
||||
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</FormGroup>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Dataset
|
||||
</button>
|
||||
</>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Sample Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Sample Every"
|
||||
value={jobConfig.config.process[0].sample.sample_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
|
||||
placeholder="eg. 250"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<SelectInput
|
||||
label="Sampler"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].sample.sampler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
|
||||
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Guidance Scale"
|
||||
value={jobConfig.config.process[0].sample.guidance_scale}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Sample Steps"
|
||||
value={jobConfig.config.process[0].sample.sample_steps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
|
||||
placeholder="eg. 1"
|
||||
className="pt-2"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Width"
|
||||
value={jobConfig.config.process[0].sample.width}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
||||
placeholder="eg. 1024"
|
||||
min={256}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Height"
|
||||
value={jobConfig.config.process[0].sample.height}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
||||
placeholder="eg. 1024"
|
||||
className="pt-2"
|
||||
min={256}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Seed"
|
||||
value={jobConfig.config.process[0].sample.seed}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<Checkbox
|
||||
label="Walk Seed"
|
||||
className="pt-4 pl-2"
|
||||
checked={jobConfig.config.process[0].sample.walk_seed}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<FormGroup
|
||||
label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`}
|
||||
className="pt-2"
|
||||
>
|
||||
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
|
||||
<div key={i} className="flex items-center space-x-2">
|
||||
<div className="flex-1">
|
||||
<TextInput
|
||||
value={prompt}
|
||||
onChange={value => setJobConfig(value, `config.process[0].sample.prompts[${i}]`)}
|
||||
placeholder="Enter prompt"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
jobConfig.config.process[0].sample.prompts.filter((_, index) => index !== i),
|
||||
'config.process[0].sample.prompts',
|
||||
)
|
||||
}
|
||||
className="rounded-full p-1 text-sm"
|
||||
>
|
||||
<X />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].sample.prompts, ''],
|
||||
'config.process[0].sample.prompts',
|
||||
)
|
||||
}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Prompt
|
||||
</button>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
|
||||
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
|
||||
</form>
|
||||
<div className="pt-20"></div>
|
||||
</MainContent>
|
||||
</>
|
||||
);
|
||||
}
|
||||
29
ui/src/app/jobs/page.tsx
Normal file
@@ -0,0 +1,29 @@
'use client';

import JobsTable from '@/components/JobsTable';
import { TopBar, MainContent } from '@/components/layout';
import Link from 'next/link';

export default function Dashboard() {
  return (
    <>
      <TopBar>
        <div>
          <h1 className="text-lg">Training Jobs</h1>
        </div>
        <div className="flex-1"></div>
        <div>
          <Link
            href="/jobs/new"
            className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
          >
            New Training Job
          </Link>
        </div>
      </TopBar>
      <MainContent>
        <JobsTable />
      </MainContent>
    </>
  );
}
38
ui/src/app/layout.tsx
Normal file
@@ -0,0 +1,38 @@
import type { Metadata } from 'next';
import { Inter } from 'next/font/google';
import './globals.css';
import Sidebar from '@/components/Sidebar';
import { ThemeProvider } from '@/components/ThemeProvider';
import ConfirmModal from '@/components/ConfirmModal';
import SampleImageModal from '@/components/SampleImageModal';
import { Suspense } from 'react';

const inter = Inter({ subsets: ['latin'] });

export const metadata: Metadata = {
  title: 'Ostris - AI Toolkit',
  description: 'A toolkit for building AI things.',
};

export default function RootLayout({ children }: { children: React.ReactNode }) {
  return (
    <html lang="en" className="dark">
      <head>
        <meta name="apple-mobile-web-app-title" content="AI-Toolkit" />
      </head>
      <body className={inter.className}>
        <ThemeProvider>
          <div className="flex h-screen bg-gray-950">
            <Sidebar />

            <main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
              <Suspense>{children}</Suspense>
            </main>
          </div>
        </ThemeProvider>
        <ConfirmModal />
        <SampleImageModal />
      </body>
    </html>
  );
}
21
ui/src/app/manifest.json
Normal file
@@ -0,0 +1,21 @@
{
  "name": "AI Toolkit",
  "short_name": "AIToolkit",
  "icons": [
    {
      "src": "/web-app-manifest-192x192.png",
      "sizes": "192x192",
      "type": "image/png",
      "purpose": "maskable"
    },
    {
      "src": "/web-app-manifest-512x512.png",
      "sizes": "512x512",
      "type": "image/png",
      "purpose": "maskable"
    }
  ],
  "theme_color": "#000000",
  "background_color": "#000000",
  "display": "standalone"
}
5
ui/src/app/page.tsx
Normal file
@@ -0,0 +1,5 @@
import { redirect } from 'next/navigation';

export default function Home() {
  redirect('/dashboard');
}
134
ui/src/app/settings/page.tsx
Normal file
@@ -0,0 +1,134 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import useSettings from '@/hooks/useSettings';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
|
||||
export default function Settings() {
|
||||
const { settings, setSettings } = useSettings();
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/settings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(settings),
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save settings');
|
||||
|
||||
setStatus('success');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
console.error('Error saving settings:', error);
|
||||
setStatus('error');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const { name, value } = e.target;
|
||||
setSettings(prev => ({ ...prev, [name]: value }));
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<TopBar>
|
||||
<div>
|
||||
<h1 className="text-lg">Settings</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div className="grid grid-cols-1 gap-6 sm:grid-cols-2">
|
||||
<div>
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label htmlFor="HF_TOKEN" className="block text-sm font-medium mb-2">
|
||||
Hugging Face Token
|
||||
<div className="text-gray-500 text-sm ml-1">
|
||||
Create a Read token on{' '}
|
||||
<a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer">
|
||||
{' '}
|
||||
Huggingface
|
||||
</a>{' '}
|
||||
if you need to access gated/private models.
|
||||
</div>
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
id="HF_TOKEN"
|
||||
name="HF_TOKEN"
|
||||
value={settings.HF_TOKEN}
|
||||
onChange={handleChange}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter your Hugging Face token"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label htmlFor="TRAINING_FOLDER" className="block text-sm font-medium mb-2">
|
||||
Training Folder Path
|
||||
<div className="text-gray-500 text-sm ml-1">
|
||||
We will store your training information here. Must be an absolute path. If blank, it will default
|
||||
to the output folder in the project root.
|
||||
</div>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="TRAINING_FOLDER"
|
||||
name="TRAINING_FOLDER"
|
||||
value={settings.TRAINING_FOLDER}
|
||||
onChange={handleChange}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter training folder path"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label htmlFor="DATASETS_FOLDER" className="block text-sm font-medium mb-2">
|
||||
Dataset Folder Path
|
||||
<div className="text-gray-500 text-sm ml-1">
|
||||
Where we store and find your datasets.{' '}
|
||||
<span className="text-orange-800">
|
||||
Warning: This software may modify datasets so it is recommended you keep a backup somewhere else
|
||||
or have a dedicated folder for this software.
|
||||
</span>
|
||||
</div>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="DATASETS_FOLDER"
|
||||
name="DATASETS_FOLDER"
|
||||
value={settings.DATASETS_FOLDER}
|
||||
onChange={handleChange}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter datasets folder path"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={status === 'saving'}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{status === 'saving' ? 'Saving...' : 'Save Settings'}
|
||||
</button>
|
||||
|
||||
{status === 'success' && <p className="text-green-500 text-center">Settings saved successfully!</p>}
|
||||
{status === 'error' && <p className="text-red-500 text-center">Error saving settings. Please try again.</p>}
|
||||
</form>
|
||||
</MainContent>
|
||||
</>
|
||||
);
|
||||
}
|
||||
155
ui/src/components/AddImagesModal.tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
'use client';
|
||||
import { createGlobalState } from 'react-global-hooks';
|
||||
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import axios from 'axios';
|
||||
|
||||
export interface AddImagesModalState {
|
||||
datasetName: string;
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export const addImagesModalState = createGlobalState<AddImagesModalState | null>(null);
|
||||
|
||||
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
|
||||
addImagesModalState.set({ datasetName, onComplete });
|
||||
}
|
||||
|
||||
export default function AddImagesModal() {
|
||||
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
|
||||
const [uploadProgress, setUploadProgress] = useState<number>(0);
|
||||
const [isUploading, setIsUploading] = useState<boolean>(false);
|
||||
const open = addImagesModalInfo !== null;
|
||||
|
||||
const onCancel = () => {
|
||||
if (!isUploading) {
|
||||
setAddImagesModalInfo(null);
|
||||
}
|
||||
};
|
||||
|
||||
const onDone = () => {
|
||||
if (addImagesModalInfo?.onComplete && !isUploading) {
|
||||
addImagesModalInfo.onComplete();
|
||||
setAddImagesModalInfo(null);
|
||||
}
|
||||
};
|
||||
|
||||
const onDrop = useCallback(async (acceptedFiles: File[]) => {
|
||||
if (acceptedFiles.length === 0) return;
|
||||
|
||||
setIsUploading(true);
|
||||
setUploadProgress(0);
|
||||
|
||||
const formData = new FormData();
|
||||
acceptedFiles.forEach(file => {
|
||||
formData.append('files', file);
|
||||
});
|
||||
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
|
||||
|
||||
try {
|
||||
await axios.post(`/api/datasets/upload`, formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
onUploadProgress: (progressEvent) => {
|
||||
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
||||
setUploadProgress(percentCompleted);
|
||||
},
|
||||
timeout: 0, // Disable timeout
|
||||
});
|
||||
|
||||
onDone();
|
||||
} catch (error) {
|
||||
console.error('Upload failed:', error);
|
||||
} finally {
|
||||
setIsUploading(false);
|
||||
setUploadProgress(0);
|
||||
}
|
||||
}, [addImagesModalInfo]);
|
||||
|
||||
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
||||
onDrop,
|
||||
accept: {
|
||||
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'],
|
||||
'text/*': ['.txt']
|
||||
},
|
||||
multiple: true
|
||||
});
|
||||
|
||||
return (
|
||||
<Dialog open={open} onClose={onCancel} className="relative z-10">
|
||||
<DialogBackdrop
|
||||
transition
|
||||
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
|
||||
/>
|
||||
|
||||
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
|
||||
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
||||
<DialogPanel
|
||||
transition
|
||||
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
|
||||
>
|
||||
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
|
||||
<div className="text-center">
|
||||
<DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
|
||||
Add Images to: {addImagesModalInfo?.datasetName}
|
||||
</DialogTitle>
|
||||
<div className="w-full">
|
||||
<div
|
||||
{...getRootProps()}
|
||||
className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
|
||||
${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
|
||||
>
|
||||
<input {...getInputProps()} />
|
||||
<FaUpload className="size-8 mb-3 text-gray-400" />
|
||||
<p className="text-sm text-gray-200 text-center">
|
||||
{isDragActive
|
||||
? 'Drop the files here...'
|
||||
: 'Drag & drop files here, or click to select files'}
|
||||
</p>
|
||||
</div>
|
||||
{isUploading && (
|
||||
<div className="mt-4">
|
||||
<div className="w-full bg-gray-700 rounded-full h-2.5">
|
||||
<div
|
||||
className="bg-blue-600 h-2.5 rounded-full"
|
||||
style={{ width: `${uploadProgress}%` }}
|
||||
></div>
|
||||
</div>
|
||||
<p className="text-sm text-gray-300 mt-2 text-center">
|
||||
Uploading... {uploadProgress}%
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onDone}
|
||||
disabled={isUploading}
|
||||
className={`inline-flex w-full justify-center rounded-md bg-slate-600 px-3 py-2 text-sm font-semibold text-white shadow-xs sm:ml-3 sm:w-auto
|
||||
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
|
||||
>
|
||||
Done
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
data-autofocus
|
||||
onClick={onCancel}
|
||||
disabled={isUploading}
|
||||
className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
|
||||
${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
</DialogPanel>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
15
ui/src/components/Card.tsx
Normal file
@@ -0,0 +1,15 @@
interface CardProps {
  title?: string;
  children?: React.ReactNode;
}

const Card: React.FC<CardProps> = ({ title, children }) => {
  return (
    <section className="space-y-2 px-4 pb-4 pt-2 bg-gray-900 rounded-lg">
      {title && <h2 className="text-lg mb-2 font-semibold uppercase text-gray-500">{title}</h2>}
      {children ? children : null}
    </section>
  );
};

export default Card;
171
ui/src/components/ConfirmModal.tsx
Normal file
@@ -0,0 +1,171 @@
|
||||
'use client';
|
||||
import { useState, useEffect } from 'react';
|
||||
import { createGlobalState } from 'react-global-hooks';
|
||||
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
|
||||
import { FaExclamationTriangle, FaInfo } from 'react-icons/fa';
|
||||
|
||||
export interface ConfirmState {
|
||||
title: string;
|
||||
message?: string;
|
||||
confirmText?: string;
|
||||
type?: 'danger' | 'warning' | 'info';
|
||||
onConfirm?: () => void;
|
||||
onCancel?: () => void;
|
||||
}
|
||||
|
||||
export const confirmstate = createGlobalState<ConfirmState | null>(null);
|
||||
|
||||
export const openConfirm = (confirmProps: ConfirmState) => {
|
||||
confirmstate.set(confirmProps);
|
||||
};
|
||||
|
||||
export default function ConfirmModal() {
|
||||
const [confirm, setConfirm] = confirmstate.use();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (confirm) {
|
||||
setIsOpen(true);
|
||||
}
|
||||
}, [confirm]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen) {
|
||||
// use timeout to allow the dialog to close before resetting the state
|
||||
setTimeout(() => {
|
||||
setConfirm(null);
|
||||
}, 500);
|
||||
}
|
||||
}, [isOpen]);
|
||||
|
||||
const onCancel = () => {
|
||||
if (confirm?.onCancel) {
|
||||
confirm.onCancel();
|
||||
}
|
||||
setIsOpen(false);
|
||||
};
|
||||
|
||||
const onConfirm = () => {
|
||||
if (confirm?.onConfirm) {
|
||||
confirm.onConfirm();
|
||||
}
|
||||
setIsOpen(false);
|
||||
};
|
||||
|
||||
let Icon = FaExclamationTriangle;
|
||||
let color = confirm?.type || 'danger';
|
||||
|
||||
// Use conditional rendering for icon
|
||||
if (color === 'info') {
|
||||
Icon = FaInfo;
|
||||
}
|
||||
|
||||
// Color mapping for background colors
|
||||
const getBgColor = () => {
|
||||
switch (color) {
|
||||
case 'danger':
|
||||
return 'bg-red-500';
|
||||
case 'warning':
|
||||
return 'bg-yellow-500';
|
||||
case 'info':
|
||||
return 'bg-blue-500';
|
||||
default:
|
||||
return 'bg-red-500';
|
||||
}
|
||||
};
|
||||
|
||||
// Color mapping for text colors
|
||||
const getTextColor = () => {
|
||||
switch (color) {
|
||||
case 'danger':
|
||||
return 'text-red-950';
|
||||
case 'warning':
|
||||
return 'text-yellow-950';
|
||||
case 'info':
|
||||
return 'text-blue-950';
|
||||
default:
|
||||
return 'text-red-950';
|
||||
}
|
||||
};
|
||||
|
||||
// Color mapping for titles
|
||||
const getTitleColor = () => {
|
||||
switch (color) {
|
||||
case 'danger':
|
||||
return 'text-red-500';
|
||||
case 'warning':
|
||||
return 'text-yellow-500';
|
||||
case 'info':
|
||||
return 'text-blue-500';
|
||||
default:
|
||||
return 'text-red-500';
|
||||
}
|
||||
};
|
||||
|
||||
// Button background color mapping
|
||||
const getButtonBgColor = () => {
|
||||
switch (color) {
|
||||
case 'danger':
|
||||
return 'bg-red-700 hover:bg-red-500';
|
||||
case 'warning':
|
||||
return 'bg-yellow-700 hover:bg-yellow-500';
|
||||
case 'info':
|
||||
return 'bg-blue-700 hover:bg-blue-500';
|
||||
default:
|
||||
return 'bg-red-700 hover:bg-red-500';
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={isOpen} onClose={onCancel} className="relative z-10">
|
||||
<DialogBackdrop
|
||||
transition
|
||||
className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
|
||||
/>
|
||||
|
||||
<div className="fixed inset-0 z-10 w-screen overflow-y-auto">
|
||||
<div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
||||
<DialogPanel
|
||||
transition
|
||||
className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
|
||||
>
|
||||
<div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
|
||||
<div className="sm:flex sm:items-start">
|
||||
<div
|
||||
className={`mx-auto flex size-12 shrink-0 items-center justify-center rounded-full ${getBgColor()} sm:mx-0 sm:size-10`}
|
||||
>
|
||||
<Icon aria-hidden="true" className={`size-6 ${getTextColor()}`} />
|
||||
</div>
|
||||
<div className="mt-3 text-center sm:mt-0 sm:ml-4 sm:text-left">
|
||||
<DialogTitle as="h3" className={`text-base font-semibold ${getTitleColor()}`}>
|
||||
{confirm?.title}
|
||||
</DialogTitle>
|
||||
<div className="mt-2">
|
||||
<p className="text-sm text-gray-200">{confirm?.message}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onConfirm}
|
||||
className={`inline-flex w-full justify-center rounded-md ${getButtonBgColor()} px-3 py-2 text-sm font-semibold text-white shadow-xs sm:ml-3 sm:w-auto`}
|
||||
>
|
||||
{confirm?.confirmText || 'Confirm'}
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
data-autofocus
|
||||
onClick={onCancel}
|
||||
className="mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
</DialogPanel>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
186
ui/src/components/DatasetImageCard.tsx
Normal file
@@ -0,0 +1,186 @@
|
||||
import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 'react';
|
||||
import { FaTrashAlt } from 'react-icons/fa';
|
||||
import { openConfirm } from './ConfirmModal';
|
||||
import classNames from 'classnames';
|
||||
|
||||
interface DatasetImageCardProps {
|
||||
imageUrl: string;
|
||||
alt: string;
|
||||
children?: ReactNode;
|
||||
className?: string;
|
||||
onDelete?: () => void;
|
||||
}
|
||||
|
||||
const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
|
||||
imageUrl,
|
||||
alt,
|
||||
children,
|
||||
className = '',
|
||||
onDelete = () => {},
|
||||
}) => {
|
||||
const cardRef = useRef<HTMLDivElement>(null);
|
||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||
const [loaded, setLoaded] = useState<boolean>(false);
|
||||
const [isCaptionLoaded, setIsCaptionLoaded] = useState<boolean>(false);
|
||||
const [caption, setCaption] = useState<string>('');
|
||||
const [savedCaption, setSavedCaption] = useState<string>('');
|
||||
const isGettingCaption = useRef<boolean>(false);
|
||||
|
||||
const fetchCaption = async () => {
|
||||
try {
|
||||
if (isGettingCaption.current || isCaptionLoaded) return;
|
||||
isGettingCaption.current = true;
|
||||
const response = await fetch(`/api/caption/${encodeURIComponent(imageUrl)}`);
|
||||
const data = await response.text();
|
||||
setCaption(data);
|
||||
setSavedCaption(data);
|
||||
setIsCaptionLoaded(true);
|
||||
} catch (error) {
|
||||
console.error('Error fetching caption:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const saveCaption = () => {
|
||||
const trimmedCaption = caption.trim();
|
||||
if (trimmedCaption === savedCaption) return;
|
||||
fetch('/api/img/caption', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ imgPath: imageUrl, caption: trimmedCaption }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Caption saved:', data);
|
||||
setSavedCaption(trimmedCaption);
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error saving caption:', error);
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
isVisible && fetchCaption();
|
||||
}, [isVisible]);
|
||||
|
||||
useEffect(() => {
|
||||
// Create intersection observer to check visibility
|
||||
const observer = new IntersectionObserver(
|
||||
entries => {
|
||||
if (entries[0].isIntersecting) {
|
||||
setIsVisible(true);
|
||||
observer.disconnect();
|
||||
}
|
||||
},
|
||||
{ threshold: 0.1 },
|
||||
);
|
||||
|
||||
if (cardRef.current) {
|
||||
observer.observe(cardRef.current);
|
||||
}
|
||||
|
||||
return () => {
|
||||
observer.disconnect();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleLoad = (): void => {
|
||||
setLoaded(true);
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>): void => {
|
||||
// If Enter is pressed without Shift, prevent default behavior and save
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
saveCaption();
|
||||
}
|
||||
};
|
||||
|
||||
const isCaptionCurrent = caption.trim() === savedCaption;
|
||||
|
||||
return (
|
||||
<div className={`flex flex-col ${className}`}>
|
||||
{/* Square image container */}
|
||||
<div
|
||||
ref={cardRef}
|
||||
className="relative w-full"
|
||||
style={{ paddingBottom: '100%' }} // Make it square
|
||||
>
|
||||
<div className="absolute inset-0 rounded-t-lg shadow-md">
|
||||
{isVisible && (
|
||||
<img
|
||||
src={`/api/img/${encodeURIComponent(imageUrl)}`}
|
||||
alt={alt}
|
||||
onLoad={handleLoad}
|
||||
className={`w-full h-full object-contain transition-opacity duration-300 ${
|
||||
loaded ? 'opacity-100' : 'opacity-0'
|
||||
}`}
|
||||
/>
|
||||
)}
|
||||
{children && <div className="absolute inset-0 flex items-center justify-center">{children}</div>}
|
||||
<div className="absolute top-1 right-1">
|
||||
<button
|
||||
className="bg-gray-800 rounded-full p-2"
|
||||
onClick={() => {
|
||||
openConfirm({
|
||||
title: 'Delete Image',
|
||||
message: 'Are you sure you want to delete this image? This action cannot be undone.',
|
||||
type: 'warning',
|
||||
confirmText: 'Delete',
|
||||
onConfirm: () => {
|
||||
fetch('/api/img/delete', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ imgPath: imageUrl }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Image deleted:', data);
|
||||
onDelete();
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error deleting image:', error);
|
||||
});
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<FaTrashAlt />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Text area below the image */}
|
||||
<div
|
||||
className={classNames('w-full p-2 bg-gray-800 text-white text-sm rounded-b-lg h-[75px]', {
|
||||
'border-blue-500 border-2': !isCaptionCurrent,
|
||||
'border-transparent border-2': isCaptionCurrent,
|
||||
})}
|
||||
>
|
||||
{isVisible && isCaptionLoaded && (
|
||||
<form
|
||||
onSubmit={e => {
|
||||
e.preventDefault();
|
||||
saveCaption();
|
||||
}}
|
||||
onBlur={saveCaption}
|
||||
>
|
||||
<textarea
|
||||
className="w-full bg-transparent resize-none outline-none focus:ring-0 focus:outline-none"
|
||||
value={caption}
|
||||
rows={3}
|
||||
onChange={e => setCaption(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
/>
|
||||
</form>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default DatasetImageCard;
|
||||
90
ui/src/components/FilesWidget.tsx
Normal file
@@ -0,0 +1,90 @@
import React from 'react';
import useFilesList from '@/hooks/useFilesList';
import Link from 'next/link';
import { Loader2, AlertCircle, Download, Box, Brain } from 'lucide-react';

export default function FilesWidget({ jobID }: { jobID: string }) {
  const { files, status, refreshFiles } = useFilesList(jobID, 5000);

  const cleanSize = (size: number) => {
    if (size < 1024) {
      return `${size} B`;
    } else if (size < 1024 * 1024) {
      return `${(size / 1024).toFixed(1)} KB`;
    } else if (size < 1024 * 1024 * 1024) {
      return `${(size / (1024 * 1024)).toFixed(1)} MB`;
    } else {
      return `${(size / (1024 * 1024 * 1024)).toFixed(1)} GB`;
    }
  };

  return (
    <div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden hover:shadow-2xl transition-all duration-300 border border-gray-800">
      <div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
        <div className="flex items-center space-x-2">
          <Brain className="w-5 h-5 text-purple-400" />
          <h2 className="font-semibold text-gray-100">Checkpoints</h2>
          <span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
            {files.length}
          </span>
        </div>
      </div>

      <div className="p-2">
        {status === 'loading' && (
          <div className="flex items-center justify-center py-4">
            <Loader2 className="w-5 h-5 text-gray-400 animate-spin" />
          </div>
        )}

        {status === 'error' && (
          <div className="flex items-center justify-center py-4 text-rose-400 space-x-2">
            <AlertCircle className="w-4 h-4" />
            <span className="text-sm">Error loading checkpoints</span>
          </div>
        )}

        {['success', 'refreshing'].includes(status) && (
          <div className="space-y-1">
            {files.map((file, index) => {
              const fileName = file.path.split('/').pop() || '';
              const nameWithoutExt = fileName.replace('.safetensors', '');
              return (
                <a
                  key={index}
                  target='_blank'
                  href={`/api/files/${encodeURIComponent(file.path)}`}
                  className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
                >
                  <div className="flex items-center space-x-2 min-w-0">
                    <Box className="w-4 h-4 text-purple-400 flex-shrink-0" />
                    <div className="flex flex-col min-w-0">
                      <div className="flex text-sm text-gray-200">
                        <span className="overflow-hidden text-ellipsis direction-rtl whitespace-nowrap">
                          {nameWithoutExt}
                        </span>
                      </div>
                      <span className="text-xs text-gray-500">.safetensors</span>
                    </div>
                  </div>
                  <div className="flex items-center space-x-3 flex-shrink-0">
                    <span className="text-xs text-gray-400">{cleanSize(file.size)}</span>
                    <div className="bg-purple-500 bg-opacity-0 group-hover:bg-opacity-10 rounded-full p-1 transition-all">
                      <Download className="w-3 h-3 text-purple-400" />
                    </div>
                  </div>
                </a>
              );
            })}
          </div>
        )}

        {['success', 'refreshing'].includes(status) && files.length === 0 && (
          <div className="text-center py-4 text-gray-400 text-sm">
            No checkpoints available
          </div>
        )}
      </div>
    </div>
  );
}
123
ui/src/components/GPUMonitor.tsx
Normal file
@@ -0,0 +1,123 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { GPUApiResponse } from '@/types';
|
||||
import Loading from '@/components/Loading';
|
||||
import GPUWidget from '@/components/GPUWidget';
|
||||
|
||||
const GpuMonitor: React.FC = () => {
|
||||
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
|
||||
const [loading, setLoading] = useState<boolean>(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchGpuInfo = async () => {
|
||||
try {
|
||||
const response = await fetch('/api/gpu');
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data: GPUApiResponse = await response.json();
|
||||
setGpuData(data);
|
||||
setLastUpdated(new Date());
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch immediately on component mount
|
||||
fetchGpuInfo();
|
||||
|
||||
// Set up interval to fetch GPU info every second
|
||||
const intervalId = setInterval(fetchGpuInfo, 1000);
|
||||
|
||||
// Clean up interval on component unmount
|
||||
return () => clearInterval(intervalId);
|
||||
}, []);
|
||||
|
||||
const getGridClasses = (gpuCount: number): string => {
|
||||
switch (gpuCount) {
|
||||
case 1:
|
||||
return 'grid-cols-1';
|
||||
case 2:
|
||||
return 'grid-cols-2';
|
||||
case 3:
|
||||
return 'grid-cols-3';
|
||||
case 4:
|
||||
return 'grid-cols-4';
|
||||
case 5:
|
||||
case 6:
|
||||
return 'grid-cols-3';
|
||||
case 7:
|
||||
case 8:
|
||||
return 'grid-cols-4';
|
||||
case 9:
|
||||
case 10:
|
||||
return 'grid-cols-5';
|
||||
default:
|
||||
return 'grid-cols-3';
|
||||
}
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return <Loading />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">Error!</strong>
|
||||
<span className="block sm:inline"> {error}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPU data available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData.hasNvidiaSmi) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
|
||||
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
|
||||
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (gpuData.gpus.length === 0) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const gridClass = getGridClasses(gpuData.gpus.length);
|
||||
|
||||
return (
|
||||
<div className="w-full">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
<h1 className="text-md">GPU Monitor</h1>
|
||||
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
|
||||
</div>
|
||||
|
||||
<div className={`grid ${gridClass} gap-3`}>
|
||||
{gpuData.gpus.map((gpu, idx) => (
|
||||
<GPUWidget key={idx} gpu={gpu} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default GpuMonitor;
|
||||
110
ui/src/components/GPUWidget.tsx
Normal file
@@ -0,0 +1,110 @@
|
||||
import React from 'react';
|
||||
import { GpuInfo } from '@/types';
|
||||
import { ChevronRight, Thermometer, Zap, Clock, HardDrive, Fan, Cpu } from 'lucide-react';
|
||||
|
||||
interface GPUWidgetProps {
|
||||
gpu: GpuInfo;
|
||||
}
|
||||
|
||||
export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
||||
const formatMemory = (mb: number): string => {
|
||||
return mb >= 1024 ? `${(mb / 1024).toFixed(1)} GB` : `${mb} MB`;
|
||||
};
|
||||
|
||||
const getUtilizationColor = (value: number): string => {
|
||||
return value < 30 ? 'bg-emerald-500' : value < 70 ? 'bg-amber-500' : 'bg-rose-500';
|
||||
};
|
||||
|
||||
const getTemperatureColor = (temp: number): string => {
|
||||
return temp < 50 ? 'text-emerald-500' : temp < 80 ? 'text-amber-500' : 'text-rose-500';
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="bg-gray-900 rounded-xl shadow-lg overflow-hidden hover:shadow-2xl transition-all duration-300 border border-gray-800">
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<h2 className="font-semibold text-gray-100">{gpu.name}</h2>
|
||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
|
||||
#{gpu.index}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="p-4 space-y-4">
|
||||
{/* Temperature, Fan, and Utilization Section */}
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center space-x-2">
|
||||
<Thermometer className={`w-4 h-4 ${getTemperatureColor(gpu.temperature)}`} />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Temperature</p>
|
||||
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>
|
||||
{gpu.temperature}°C
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Fan className="w-4 h-4 text-blue-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Fan Speed</p>
|
||||
<p className="text-sm font-medium text-blue-400">
|
||||
{gpu.fan.speed}%
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className="flex items-center space-x-2 mb-1">
|
||||
<Cpu className="w-4 h-4 text-gray-400" />
|
||||
<p className="text-xs text-gray-400">GPU Load</p>
|
||||
<span className="text-xs text-gray-300 ml-auto">{gpu.utilization.gpu}%</span>
|
||||
</div>
|
||||
<div className="w-full bg-gray-700 rounded-full h-1">
|
||||
<div
|
||||
className={`h-1 rounded-full transition-all ${getUtilizationColor(gpu.utilization.gpu)}`}
|
||||
style={{ width: `${gpu.utilization.gpu}%` }}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2 mb-1 mt-3">
|
||||
<HardDrive className="w-4 h-4 text-blue-400" />
|
||||
<p className="text-xs text-gray-400">Memory</p>
|
||||
<span className="text-xs text-gray-300 ml-auto">
|
||||
{((gpu.memory.used / gpu.memory.total) * 100).toFixed(1)}%
|
||||
</span>
|
||||
</div>
|
||||
<div className="w-full bg-gray-700 rounded-full h-1">
|
||||
<div
|
||||
className="h-1 rounded-full bg-blue-500 transition-all"
|
||||
style={{ width: `${(gpu.memory.used / gpu.memory.total) * 100}%` }}
|
||||
/>
|
||||
</div>
|
||||
<p className="text-xs text-gray-400 mt-0.5">
|
||||
{formatMemory(gpu.memory.used)} / {formatMemory(gpu.memory.total)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Power and Clocks Section */}
|
||||
<div className="grid grid-cols-2 gap-4 pt-2 border-t border-gray-800">
|
||||
<div className="flex items-start space-x-2">
|
||||
<Clock className="w-4 h-4 text-purple-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Clock Speed</p>
|
||||
<p className="text-sm text-gray-200">{gpu.clocks.graphics} MHz</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-start space-x-2">
|
||||
<Zap className="w-4 h-4 text-amber-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Power Draw</p>
|
||||
<p className="text-sm text-gray-200">
|
||||
{gpu.power.draw.toFixed(1)}W
|
||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit.toFixed(1)}W</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
87
ui/src/components/JobActionBar.tsx
Normal file
@@ -0,0 +1,87 @@
|
||||
import Link from 'next/link';
|
||||
import { Eye, Trash2, Pen, Play, Pause } from 'lucide-react';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { openConfirm } from '@/components/ConfirmModal';
|
||||
import { Job } from '@prisma/client';
|
||||
import { startJob, stopJob, deleteJob, getAvaliableJobActions } from '@/utils/jobs';
|
||||
|
||||
interface JobActionBarProps {
|
||||
job: Job;
|
||||
onRefresh?: () => void;
|
||||
afterDelete?: () => void;
|
||||
hideView?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function JobActionBar({ job, onRefresh, afterDelete, className, hideView }: JobActionBarProps) {
|
||||
const { canStart, canStop, canDelete, canEdit } = getAvaliableJobActions(job);
|
||||
|
||||
if (!afterDelete) afterDelete = onRefresh;
|
||||
|
||||
return (
|
||||
<div className={`${className}`}>
|
||||
{canStart && (
|
||||
<Button
|
||||
onClick={async () => {
|
||||
if (!canStart) return;
|
||||
await startJob(job.id);
|
||||
if (onRefresh) onRefresh();
|
||||
}}
|
||||
className={`ml-2 opacity-100`}
|
||||
>
|
||||
<Play />
|
||||
</Button>
|
||||
)}
|
||||
{canStop && (
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (!canStop) return;
|
||||
openConfirm({
|
||||
title: 'Stop Job',
|
||||
message: `Are you sure you want to stop the job "${job.name}"? You CAN resume later.`,
|
||||
type: 'info',
|
||||
confirmText: 'Stop',
|
||||
onConfirm: async () => {
|
||||
await stopJob(job.id);
|
||||
if (onRefresh) onRefresh();
|
||||
},
|
||||
});
|
||||
}}
|
||||
className={`ml-2 opacity-100`}
|
||||
>
|
||||
<Pause />
|
||||
</Button>
|
||||
)}
|
||||
{!hideView && (
|
||||
<Link href={`/jobs/${job.id}`} className="ml-2 text-gray-200 hover:text-gray-100 inline-block">
|
||||
<Eye />
|
||||
</Link>
|
||||
)}
|
||||
{canEdit && (
|
||||
<Link href={`/jobs/new?id=${job.id}`} className="ml-2 hover:text-gray-100 inline-block">
|
||||
<Pen />
|
||||
</Link>
|
||||
)}
|
||||
{canDelete && (
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (!canDelete) return;
|
||||
openConfirm({
|
||||
title: 'Delete Job',
|
||||
message: `Are you sure you want to delete the job "${job.name}"? This will also permanently remove it from your disk.`,
|
||||
type: 'warning',
|
||||
confirmText: 'Delete',
|
||||
onConfirm: async () => {
|
||||
await deleteJob(job.id);
|
||||
if (afterDelete) afterDelete();
|
||||
},
|
||||
});
|
||||
}}
|
||||
className={`ml-2 opacity-100`}
|
||||
>
|
||||
<Trash2 />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
104
ui/src/components/JobOverview.tsx
Normal file
@@ -0,0 +1,104 @@
|
||||
import { Job } from '@prisma/client';
|
||||
import useGPUInfo from '@/hooks/useGPUInfo';
|
||||
import GPUWidget from '@/components/GPUWidget';
|
||||
import FilesWidget from '@/components/FilesWidget';
|
||||
import { getTotalSteps } from '@/utils/jobs';
|
||||
import { Cpu, HardDrive, Info, Gauge } from 'lucide-react';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
interface JobOverviewProps {
|
||||
job: Job;
|
||||
}
|
||||
|
||||
export default function JobOverview({ job }: JobOverviewProps) {
|
||||
const gpuIds = useMemo(() => job.gpu_ids.split(',').map(id => parseInt(id)), [job.gpu_ids]);
|
||||
|
||||
const { gpuList, isGPUInfoLoaded } = useGPUInfo(gpuIds, 5000);
|
||||
const totalSteps = getTotalSteps(job);
|
||||
const progress = (job.step / totalSteps) * 100;
|
||||
const isStopping = job.stop && job.status === 'running';
|
||||
|
||||
const getStatusColor = (status: string) => {
|
||||
switch (status.toLowerCase()) {
|
||||
case 'running':
|
||||
return 'bg-emerald-500/10 text-emerald-500';
|
||||
case 'stopping':
|
||||
return 'bg-amber-500/10 text-amber-500';
|
||||
case 'stopped':
|
||||
return 'bg-gray-500/10 text-gray-400';
|
||||
case 'completed':
|
||||
return 'bg-blue-500/10 text-blue-500';
|
||||
case 'error':
|
||||
return 'bg-rose-500/10 text-rose-500';
|
||||
default:
|
||||
return 'bg-gray-500/10 text-gray-400';
|
||||
}
|
||||
};
|
||||
|
||||
let status = job.status;
|
||||
if (isStopping) {
|
||||
status = 'stopping';
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="grid grid-cols-1 gap-6 md:grid-cols-3">
|
||||
{/* Job Information Panel */}
|
||||
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
|
||||
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(status)}`}>{status}</span>
|
||||
</div>
|
||||
|
||||
<div className="p-4 space-y-6">
|
||||
{/* Progress Bar */}
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between text-sm">
|
||||
<span className="text-gray-400">Progress</span>
|
||||
<span className="text-gray-200">
|
||||
Step {job.step} of {totalSteps}
|
||||
</span>
|
||||
</div>
|
||||
<div className="w-full bg-gray-800 rounded-full h-2">
|
||||
<div className="h-2 rounded-full bg-blue-500 transition-all" style={{ width: `${progress}%` }} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Job Info Grid */}
|
||||
<div className="grid gap-4 grid-cols-1 md:grid-cols-3">
|
||||
<div className="flex items-center space-x-4">
|
||||
<HardDrive className="w-5 h-5 text-blue-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Job Name</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.name}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-4">
|
||||
<Cpu className="w-5 h-5 text-purple-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Assigned GPUs</p>
|
||||
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-4">
|
||||
<Gauge className="w-5 h-5 text-green-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Speed</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* GPU Widget Panel */}
|
||||
<div className="col-span-1">
|
||||
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
|
||||
<div className="mt-4">
|
||||
<FilesWidget jobID={job.id} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
79
ui/src/components/JobsTable.tsx
Normal file
@@ -0,0 +1,79 @@
import useJobsList from '@/hooks/useJobsList';
import Link from 'next/link';
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
import { JobConfig } from '@/types';
import JobActionBar from './JobActionBar';

interface JobsTableProps {
  onlyActive?: boolean;
}

export default function JobsTable({ onlyActive = false }: JobsTableProps) {
  const { jobs, status, refreshJobs } = useJobsList(onlyActive);
  const isLoading = status === 'loading';

  const columns: TableColumn[] = [
    {
      title: 'Name',
      key: 'name',
      render: row => (
        <Link href={`/jobs/${row.id}`} className="font-medium whitespace-nowrap">
          {row.name}
        </Link>
      ),
    },
    {
      title: 'Steps',
      key: 'steps',
      render: row => {
        const jobConfig: JobConfig = JSON.parse(row.job_config);
        const totalSteps = jobConfig.config.process[0].train.steps;

        return (
          <div className="flex items-center">
            <span>
              {row.step} / {totalSteps}
            </span>
            <div className="w-16 bg-gray-700 rounded-full h-1.5 ml-2">
              <div
                className="bg-blue-500 h-1.5 rounded-full"
                style={{ width: `${(row.step / totalSteps) * 100}%` }}
              ></div>
            </div>
          </div>
        );
      },
    },
    {
      title: 'GPU',
      key: 'gpu_ids',
    },
    {
      title: 'Status',
      key: 'status',
      render: row => {
        let statusClass = 'text-gray-400';
        if (row.status === 'completed') statusClass = 'text-green-400';
        if (row.status === 'failed') statusClass = 'text-red-400';
        if (row.status === 'running') statusClass = 'text-blue-400';

        return <span className={statusClass}>{row.status}</span>;
      },
    },
    {
      title: 'Info',
      key: 'info',
      className: 'truncate max-w-xs',
    },
    {
      title: 'Actions',
      key: 'actions',
      className: 'text-right',
      render: row => {
        return <JobActionBar job={row} onRefresh={refreshJobs} />;
      },
    },
  ];

  return <UniversalTable columns={columns} rows={jobs} isLoading={isLoading} onRefresh={refreshJobs} />;
}
7
ui/src/components/Loading.tsx
Normal file
@@ -0,0 +1,7 @@
export default function Loading() {
  return (
    <div className="flex justify-center items-center h-64">
      <div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
    </div>
  );
}
110
ui/src/components/Modal.tsx
Normal file
@@ -0,0 +1,110 @@
import React, { Fragment, useEffect } from 'react';

interface ModalProps {
  isOpen: boolean;
  onClose: () => void;
  title?: string;
  children: React.ReactNode;
  showCloseButton?: boolean;
  size?: 'sm' | 'md' | 'lg' | 'xl';
  closeOnOverlayClick?: boolean;
}

export const Modal: React.FC<ModalProps> = ({
  isOpen,
  onClose,
  title,
  children,
  showCloseButton = true,
  size = 'md',
  closeOnOverlayClick = true,
}) => {
  // Close on ESC key press
  useEffect(() => {
    const handleEscKey = (e: KeyboardEvent) => {
      if (e.key === 'Escape' && isOpen) {
        onClose();
      }
    };

    if (isOpen) {
      document.addEventListener('keydown', handleEscKey);
      // Prevent body scrolling when modal is open
      document.body.style.overflow = 'hidden';
    }

    return () => {
      document.removeEventListener('keydown', handleEscKey);
      document.body.style.overflow = 'auto';
    };
  }, [isOpen, onClose]);

  // Handle overlay click
  const handleOverlayClick = (e: React.MouseEvent<HTMLDivElement>) => {
    if (e.target === e.currentTarget && closeOnOverlayClick) {
      onClose();
    }
  };

  if (!isOpen) return null;

  // Size mapping
  const sizeClasses = {
    sm: 'max-w-md',
    md: 'max-w-lg',
    lg: 'max-w-2xl',
    xl: 'max-w-4xl',
  };

  return (
    <Fragment>
      {/* Modal backdrop */}
      <div
        className="fixed inset-0 z-50 flex items-center justify-center overflow-y-auto bg-gray-900 bg-opacity-75 backdrop-blur-sm transition-opacity"
        onClick={handleOverlayClick}
        aria-modal="true"
        role="dialog"
        aria-labelledby="modal-title"
      >
        {/* Modal panel */}
        <div
          className={`relative mx-auto w-full ${sizeClasses[size]} rounded-lg bg-gray-800 border border-gray-700 shadow-xl transition-all`}
          onClick={e => e.stopPropagation()}
        >
          {/* Modal header */}
          {(title || showCloseButton) && (
            <div className="flex items-center justify-between rounded-t-lg border-b border-gray-700 bg-gray-850 px-6 py-4">
              {title && (
                <h3 id="modal-title" className="text-xl font-semibold text-gray-100">
                  {title}
                </h3>
              )}

              {showCloseButton && (
                <button
                  type="button"
                  className="ml-auto inline-flex items-center justify-center rounded-md p-2 text-gray-400 hover:bg-gray-700 hover:text-gray-200 focus:outline-none focus:ring-2 focus:ring-blue-500"
                  onClick={onClose}
                  aria-label="Close modal"
                >
                  <svg
                    className="h-5 w-5"
                    xmlns="http://www.w3.org/2000/svg"
                    fill="none"
                    viewBox="0 0 24 24"
                    stroke="currentColor"
                  >
                    <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M6 18L18 6M6 6l12 12" />
                  </svg>
                </button>
              )}
            </div>
          )}

          {/* Modal content */}
          <div className="px-6 py-4">{children}</div>
        </div>
      </div>
    </Fragment>
  );
};
71
ui/src/components/SampleImageCard.tsx
Normal file
@@ -0,0 +1,71 @@
import React, { useRef, useEffect, useState, ReactNode } from 'react';
import { sampleImageModalState } from '@/components/SampleImageModal';

interface SampleImageCardProps {
  imageUrl: string;
  alt: string;
  numSamples: number;
  sampleImages: string[];
  children?: ReactNode;
  className?: string;
  onDelete?: () => void;
}

const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, numSamples, sampleImages, children, className = '' }) => {
  const cardRef = useRef<HTMLDivElement>(null);
  const [isVisible, setIsVisible] = useState<boolean>(false);
  const [loaded, setLoaded] = useState<boolean>(false);

  useEffect(() => {
    // Create intersection observer to check visibility
    const observer = new IntersectionObserver(
      entries => {
        if (entries[0].isIntersecting) {
          setIsVisible(true);
          observer.disconnect();
        }
      },
      { threshold: 0.1 },
    );

    if (cardRef.current) {
      observer.observe(cardRef.current);
    }

    return () => {
      observer.disconnect();
    };
  }, []);

  const handleLoad = (): void => {
    setLoaded(true);
  };

  return (
    <div className={`flex flex-col ${className}`}>
      {/* Square image container */}
      <div
        ref={cardRef}
        className="relative w-full cursor-pointer"
        style={{ paddingBottom: '100%' }} // Make it square
        onClick={() => sampleImageModalState.set({ imgPath: imageUrl, numSamples, sampleImages })}
      >
        <div className="absolute inset-0 rounded-t-lg shadow-md">
          {isVisible && (
            <img
              src={`/api/img/${encodeURIComponent(imageUrl)}`}
              alt={alt}
              onLoad={handleLoad}
              className={`w-full h-full object-contain transition-opacity duration-300 ${
                loaded ? 'opacity-100' : 'opacity-0'
              }`}
            />
          )}
          {children && <div className="absolute inset-0 flex items-center justify-center">{children}</div>}
        </div>
      </div>
    </div>
  );
};

export default SampleImageCard;
190
ui/src/components/SampleImageModal.tsx
Normal file
@@ -0,0 +1,190 @@
'use client';
import { useState, useEffect, useMemo } from 'react';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel } from '@headlessui/react';

export interface SampleImageModalState {
  imgPath: string;
  numSamples: number;
  sampleImages: string[];
}

export const sampleImageModalState = createGlobalState<SampleImageModalState | null>(null);

export const openSampleImage = (sampleImageProps: SampleImageModalState) => {
  sampleImageModalState.set(sampleImageProps);
};

export default function SampleImageModal() {
  const [imageModal, setImageModal] = sampleImageModalState.use();
  const [isOpen, setIsOpen] = useState(false);

  useEffect(() => {
    if (imageModal) {
      setIsOpen(true);
    }
  }, [imageModal]);

  useEffect(() => {
    if (!isOpen) {
      // use timeout to allow the dialog to close before resetting the state
      setTimeout(() => {
        setImageModal(null);
      }, 500);
    }
  }, [isOpen]);

  const onCancel = () => {
    setIsOpen(false);
  };

  const imgInfo = useMemo(() => {
    const ii = {
      filename: '',
      step: 0,
      promptIdx: 0,
    };
    if (imageModal?.imgPath) {
      const filename = imageModal.imgPath.split('/').pop();
      if (!filename) return ii;
      // filename is <timestep>__<zero_pad_step>_<prompt_idx>.<ext>
      ii.filename = filename as string;
      const parts = filename
        .split('.')[0]
        .split('_')
        .filter(p => p !== '');
      if (parts.length === 3) {
        ii.step = parseInt(parts[1]);
        ii.promptIdx = parseInt(parts[2]);
      }
    }
    return ii;
  }, [imageModal]);

  const handleArrowUp = () => {
    if (!imageModal) return;
    console.log('Arrow Up pressed');
    // Change image to same sample but up one step
    const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
    if (currentIdx === -1) return;
    const nextIdx = currentIdx - imageModal.numSamples;
    if (nextIdx < 0) return;
    openSampleImage({
      imgPath: imageModal.sampleImages[nextIdx],
      numSamples: imageModal.numSamples,
      sampleImages: imageModal.sampleImages,
    });
  };

  const handleArrowDown = () => {
    if (!imageModal) return;
    console.log('Arrow Down pressed');
    // Change image to same sample but down one step
    const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
    if (currentIdx === -1) return;
    const nextIdx = currentIdx + imageModal.numSamples;
    if (nextIdx >= imageModal.sampleImages.length) return;
    openSampleImage({
      imgPath: imageModal.sampleImages[nextIdx],
      numSamples: imageModal.numSamples,
      sampleImages: imageModal.sampleImages,
    });
  };

  const handleArrowLeft = () => {
    if (!imageModal) return;
    if (imgInfo.promptIdx === 0) return;
    console.log('Arrow Left pressed');
    // go to previous sample
    const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
    if (currentIdx === -1) return;
    const minIdx = currentIdx - imgInfo.promptIdx;
    const nextIdx = currentIdx - 1;
    if (nextIdx < minIdx) return;
    openSampleImage({
      imgPath: imageModal.sampleImages[nextIdx],
      numSamples: imageModal.numSamples,
      sampleImages: imageModal.sampleImages,
    });
  };

  const handleArrowRight = () => {
    if (!imageModal) return;
    console.log('Arrow Right pressed');
    // go to next sample
    const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath);
    if (currentIdx === -1) return;
    const stepMinIdx = currentIdx - imgInfo.promptIdx;
    const maxIdx = stepMinIdx + imageModal.numSamples - 1;
    const nextIdx = currentIdx + 1;
    if (nextIdx > maxIdx) return;
    if (nextIdx >= imageModal.sampleImages.length) return;
    openSampleImage({
      imgPath: imageModal.sampleImages[nextIdx],
      numSamples: imageModal.numSamples,
      sampleImages: imageModal.sampleImages,
    });
  };

  // Handle keyboard events
  useEffect(() => {
    const handleKeyDown = (event: KeyboardEvent) => {
      if (!isOpen) return;

      switch (event.key) {
        case 'Escape':
          onCancel();
          break;
        case 'ArrowUp':
          handleArrowUp();
          break;
        case 'ArrowDown':
          handleArrowDown();
          break;
        case 'ArrowLeft':
          handleArrowLeft();
          break;
        case 'ArrowRight':
          handleArrowRight();
          break;
        default:
          break;
      }
    };

    window.addEventListener('keydown', handleKeyDown);

    return () => {
      window.removeEventListener('keydown', handleKeyDown);
    };
  }, [isOpen, imageModal, imgInfo]);

  return (
    <Dialog open={isOpen} onClose={onCancel} className="relative z-10">
      <DialogBackdrop
        transition
        className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
      />

      <div className="fixed inset-0 z-10 w-screen overflow-y-auto">
        <div className="flex min-h-full items-center justify-center p-4 text-center">
          <DialogPanel
            transition
            className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in max-w-[95%] max-h-[95vh] data-closed:sm:translate-y-0 data-closed:sm:scale-95"
          >
            <div className="flex justify-center items-center">
              {imageModal?.imgPath && (
                <img
                  src={`/api/img/${encodeURIComponent(imageModal.imgPath)}`}
                  alt="Sample Image"
                  className="max-w-full max-h-[calc(95vh-2rem)] object-contain"
                />
              )}
            </div>
            <div className="bg-gray-950 text-center text-sm p-2">step: {imgInfo.step}</div>
          </DialogPanel>
        </div>
      </div>
    </Dialog>
  );
}
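Editorial note (not part of the diff): the modal's arrow-key handlers above assume the flat `sampleImages` list is ordered step by step, with `numSamples` entries (one per prompt) for each sampling step, and the filename encodes the step and prompt index. A rough sketch of that indexing, with made-up filenames:

```ts
// Hypothetical sample list for a job with 2 prompts (numSamples = 2).
// Filenames follow <timestep>__<zero_pad_step>_<prompt_idx>.<ext>.
const sampleImages = [
  'samples/1740000000__000000250_0.jpg', // step 250, prompt 0
  'samples/1740000000__000000250_1.jpg', // step 250, prompt 1
  'samples/1740000060__000000500_0.jpg', // step 500, prompt 0
  'samples/1740000060__000000500_1.jpg', // step 500, prompt 1
];
const numSamples = 2;

const currentIdx = 2;                    // viewing step 500, prompt 0
const upIdx = currentIdx - numSamples;   // 0 -> same prompt in the adjacent sampling step (ArrowUp)
const downIdx = currentIdx + numSamples; // 4 -> out of range, so ArrowDown is ignored here
const rightIdx = currentIdx + 1;         // 3 -> next prompt within the same step (ArrowRight)
```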
95
ui/src/components/SampleImages.tsx
Normal file
@@ -0,0 +1,95 @@
import { useMemo } from 'react';
import useSampleImages from '@/hooks/useSampleImages';
import SampleImageCard from './SampleImageCard';
import { Job } from '@prisma/client';
import { JobConfig } from '@/types';

interface SampleImagesProps {
  job: Job;
}

export default function SampleImages({ job }: SampleImagesProps) {
  const { sampleImages, status, refreshSampleImages } = useSampleImages(job.id, 5000);
  const numSamples = useMemo(() => {
    if (job?.job_config) {
      const jobConfig = JSON.parse(job.job_config) as JobConfig;
      const sampleConfig = jobConfig.config.process[0].sample;
      return sampleConfig.prompts.length;
    }
    return 10;
  }, [job]);

  // Use direct Tailwind class without string interpolation
  // This way Tailwind can properly generate the class
  // I hate this, but it's the only way to make it work
  const gridColsClass = useMemo(() => {
    const cols = Math.min(numSamples, 20);

    switch (cols) {
      case 1:
        return 'grid-cols-1';
      case 2:
        return 'grid-cols-2';
      case 3:
        return 'grid-cols-3';
      case 4:
        return 'grid-cols-4';
      case 5:
        return 'grid-cols-5';
      case 6:
        return 'grid-cols-6';
      case 7:
        return 'grid-cols-7';
      case 8:
        return 'grid-cols-8';
      case 9:
        return 'grid-cols-9';
      case 10:
        return 'grid-cols-10';
      case 11:
        return 'grid-cols-11';
      case 12:
        return 'grid-cols-12';
      case 13:
        return 'grid-cols-13';
      case 14:
        return 'grid-cols-14';
      case 15:
        return 'grid-cols-15';
      case 16:
        return 'grid-cols-16';
      case 17:
        return 'grid-cols-17';
      case 18:
        return 'grid-cols-18';
      case 19:
        return 'grid-cols-19';
      case 20:
        return 'grid-cols-20';
      default:
        return 'grid-cols-1';
    }
  }, [numSamples]);

  return (
    <div>
      <div className="pb-4">
        {status === 'loading' && sampleImages.length === 0 && <p>Loading...</p>}
        {status === 'error' && <p>Error fetching sample images</p>}
        {sampleImages && (
          <div className={`grid ${gridColsClass} gap-1`}>
            {sampleImages.map((sample: string) => (
              <SampleImageCard
                key={sample}
                imageUrl={sample}
                numSamples={numSamples}
                sampleImages={sampleImages}
                alt="Sample Image"
              />
            ))}
          </div>
        )}
      </div>
    </div>
  );
}
65
ui/src/components/Sidebar.tsx
Normal file
@@ -0,0 +1,65 @@
import Link from 'next/link';
import { Home, Settings, BrainCircuit, Images } from 'lucide-react';

const Sidebar = () => {
  const navigation = [
    { name: 'Dashboard', href: '/dashboard', icon: Home },
    { name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
    { name: 'Datasets', href: '/datasets', icon: Images },
    { name: 'Settings', href: '/settings', icon: Settings },
  ];

  return (
    <div className="flex flex-col w-64 bg-gray-900 text-gray-100">
      <div className="p-4">
        <h1 className="text-xl">
          <img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-8 mr-3 inline" />
          Ostris - AI Toolkit
        </h1>
      </div>
      <nav className="flex-1">
        <ul className="px-2 py-4 space-y-2">
          {navigation.map(item => (
            <li key={item.name}>
              <Link
                href={item.href}
                className="flex items-center px-4 py-2 text-gray-300 hover:bg-gray-800 rounded-lg transition-colors"
              >
                <item.icon className="w-5 h-5 mr-3" />
                {item.name}
              </Link>
            </li>
          ))}
        </ul>
      </nav>
      <a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
        <div className='min-w-[26px] min-h-[26px]'>
          <svg
            viewBox="0 0 512 512"
            xmlns="http://www.w3.org/2000/svg"
            fillRule="evenodd"
            clipRule="evenodd"
            strokeLinejoin="round"
            strokeMiterlimit="2"
          >
            <g transform="matrix(.47407 0 0 .47407 .383 .422)">
              <clipPath id="prefix__a">
                <path d="M0 0h1080v1080H0z"></path>
              </clipPath>
              <g clipPath="url(#prefix__a)">
                <path
                  d="M1033.05 324.45c-.19-137.9-107.59-250.92-233.6-291.7-156.48-50.64-362.86-43.3-512.28 27.2-181.1 85.46-237.99 272.66-240.11 459.36-1.74 153.5 13.58 557.79 241.62 560.67 169.44 2.15 194.67-216.18 273.07-321.33 55.78-74.81 127.6-95.94 216.01-117.82 151.95-37.61 255.51-157.53 255.29-316.38z"
                  fillRule="nonzero"
                  fill="#ffffff"
                ></path>
              </g>
            </g>
          </svg>
        </div>
        <div className="text-gray-500 text-md mb-2 flex-1 pt-2 pl-2">Support me on Patreon</div>
      </a>
    </div>
  );
};

export default Sidebar;
11
ui/src/components/ThemeProvider.tsx
Normal file
@@ -0,0 +1,11 @@
'use client';

import { createContext, useContext, useEffect, useState } from 'react';

const ThemeContext = createContext({ isDark: true });

export const ThemeProvider = ({ children }: { children: React.ReactNode }) => {
  const [isDark, setIsDark] = useState(true);

  return <ThemeContext.Provider value={{ isDark }}>{children}</ThemeContext.Provider>;
};
72
ui/src/components/UniversalTable.tsx
Normal file
@@ -0,0 +1,72 @@
import Loading from './Loading';
import classNames from 'classnames';

export interface TableColumn {
  title: string;
  key: string;
  render?: (row: any) => React.ReactNode;
  className?: string;
}

interface TableRow {
  [key: string]: any;
}

interface TableProps {
  columns: TableColumn[];
  rows: TableRow[];
  isLoading: boolean;
  onRefresh: () => void;
}

export default function UniversalTable({ columns, rows, isLoading, onRefresh = () => {} }: TableProps) {
  return (
    <div className="w-full bg-gray-900 rounded-md shadow-md">
      {isLoading ? (
        <div className="p-4 flex justify-center">
          <Loading />
        </div>
      ) : rows.length === 0 ? (
        <div className="p-6 text-center text-gray-400">
          <p className="text-sm">Empty</p>
          <button
            onClick={() => onRefresh()}
            className="mt-2 px-3 py-1 text-xs bg-gray-800 hover:bg-gray-700 text-gray-300 rounded transition-colors"
          >
            Refresh
          </button>
        </div>
      ) : (
        <div className="overflow-x-auto">
          <table className="w-full text-sm text-left text-gray-300">
            <thead className="text-xs uppercase bg-gray-800 text-gray-400">
              <tr>
                {columns.map(column => (
                  <th key={column.key} className="px-3 py-2">
                    {column.title}
                  </th>
                ))}
              </tr>
            </thead>
            <tbody>
              {rows?.map((row, index) => {
                // Style for alternating rows
                const rowClass = index % 2 === 0 ? 'bg-gray-900' : 'bg-gray-800';

                return (
                  <tr key={index} className={`${rowClass} border-b border-gray-700 hover:bg-gray-700`}>
                    {columns.map(column => (
                      <td key={column.key} className={classNames('px-3 py-2', column.className)}>
                        {column.render ? column.render(row) : row[column.key]}
                      </td>
                    ))}
                  </tr>
                );
              })}
            </tbody>
          </table>
        </div>
      )}
    </div>
  );
}
194
ui/src/components/formInputs.tsx
Normal file
@@ -0,0 +1,194 @@
import React from 'react';
import classNames from 'classnames';

const labelClasses = 'block text-xs mb-1 mt-2 text-gray-300';
const inputClasses =
  'w-full text-sm px-3 py-1 bg-gray-800 border border-gray-700 rounded-sm focus:ring-2 focus:ring-gray-600 focus:border-transparent';

export interface InputProps {
  label?: string;
  className?: string;
  placeholder?: string;
  required?: boolean;
}

export interface TextInputProps extends InputProps {
  value: string;
  onChange: (value: string) => void;
  type?: 'text' | 'password';
  disabled?: boolean;
}

export const TextInput = (props: TextInputProps) => {
  const { label, value, onChange, placeholder, required, disabled } = props;
  return (
    <div className={classNames(props.className)}>
      {label && <label className={labelClasses}>{label}</label>}
      <input
        type={props.type || 'text'}
        value={value}
        onChange={e => {
          if (disabled) return;
          onChange(e.target.value);
        }}
        className={`${inputClasses} ${disabled && 'opacity-30 cursor-not-allowed'}`}
        placeholder={placeholder}
        required={required}
        disabled={disabled}
      />
    </div>
  );
};

export interface NumberInputProps extends InputProps {
  value: number;
  onChange: (value: number) => void;
  min?: number;
  max?: number;
}

export const NumberInput = (props: NumberInputProps) => {
  const { label, value, onChange, placeholder, required, min, max } = props;

  // Add controlled internal state to properly handle partial inputs
  const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');

  // Sync internal state with prop value
  React.useEffect(() => {
    setInputValue(value ?? '');
  }, [value]);

  return (
    <div className={classNames(props.className)}>
      {label && <label className={labelClasses}>{label}</label>}
      <input
        type="number"
        value={inputValue}
        onChange={e => {
          const rawValue = e.target.value;

          // Update the input display with the raw value
          setInputValue(rawValue);

          // Handle empty or partial inputs
          if (rawValue === '' || rawValue === '-') {
            // For empty or partial negative input, don't call onChange yet
            return;
          }

          const numValue = Number(rawValue);

          // Only apply constraints and call onChange when we have a valid number
          if (!isNaN(numValue)) {
            let constrainedValue = numValue;

            // Apply min/max constraints if they exist
            if (min !== undefined && constrainedValue < min) {
              constrainedValue = min;
            }
            if (max !== undefined && constrainedValue > max) {
              constrainedValue = max;
            }

            onChange(constrainedValue);
          }
        }}
        className={inputClasses}
        placeholder={placeholder}
        required={required}
        min={min}
        max={max}
        step="any"
      />
    </div>
  );
};

export interface SelectInputProps extends InputProps {
  value: string;
  onChange: (value: string) => void;
  options: { value: string; label: string }[];
}

export const SelectInput = (props: SelectInputProps) => {
  const { label, value, onChange, options } = props;
  return (
    <div className={classNames(props.className)}>
      {label && <label className={labelClasses}>{label}</label>}
      <select value={value} onChange={e => onChange(e.target.value)} className={inputClasses}>
        {options.map(option => (
          <option key={option.value} value={option.value}>
            {option.label}
          </option>
        ))}
      </select>
    </div>
  );
};

export interface CheckboxProps {
  label?: string;
  checked: boolean;
  onChange: (checked: boolean) => void;
  className?: string;
  required?: boolean;
  disabled?: boolean;
}

export const Checkbox = (props: CheckboxProps) => {
  const { label, checked, onChange, required, disabled } = props;
  const id = React.useId();

  return (
    <div className={classNames('flex items-center gap-3', props.className)}>
      <button
        type="button"
        role="switch"
        id={id}
        aria-checked={checked}
        aria-required={required}
        disabled={disabled}
        onClick={() => !disabled && onChange(!checked)}
        className={classNames(
          'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
          checked ? 'bg-blue-600' : 'bg-gray-700',
          disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
        )}
      >
        <span className="sr-only">Toggle {label}</span>
        <span
          className={classNames(
            'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
            checked ? 'translate-x-5' : 'translate-x-0'
          )}
        />
      </button>
      {label && (
        <label
          htmlFor={id}
          className={classNames(
            'text-sm font-medium cursor-pointer select-none',
            disabled ? 'text-gray-500' : 'text-gray-300'
          )}
        >
          {label}
        </label>
      )}
    </div>
  );
};

interface FormGroupProps {
  label?: string;
  className?: string;
  children: React.ReactNode;
}

export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children }) => {
  return (
    <div className={classNames(className)}>
      {label && <label className={labelClasses}>{label}</label>}
      <div className="px-4 space-y-2">{children}</div>
    </div>
  );
};
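As a quick usage sketch (not part of the diff): the inputs above are controlled components, so the parent owns the value and receives updates through `onChange`. The field names and values below are illustrative only:

```tsx
import { useState } from 'react';
import { TextInput, NumberInput, SelectInput, Checkbox, FormGroup } from '@/components/formInputs';

// Hypothetical form wiring; field names are made up for illustration.
export function ExampleTrainingForm() {
  const [name, setName] = useState('my_lora');
  const [steps, setSteps] = useState(2000);
  const [dtype, setDtype] = useState('bf16');
  const [quantize, setQuantize] = useState(true);

  return (
    <FormGroup label="Training">
      <TextInput label="Name" value={name} onChange={setName} />
      <NumberInput label="Steps" value={steps} onChange={setSteps} min={1} max={100000} />
      <SelectInput
        label="Dtype"
        value={dtype}
        onChange={setDtype}
        options={[
          { value: 'bf16', label: 'bf16' },
          { value: 'fp16', label: 'fp16' },
        ]}
      />
      <Checkbox label="Quantize" checked={quantize} onChange={setQuantize} />
    </FormGroup>
  );
}
```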
27
ui/src/components/layout.tsx
Normal file
@@ -0,0 +1,27 @@
import classNames from 'classnames';

interface Props {
  className?: string;
  children?: React.ReactNode;
}

export const TopBar: React.FC<Props> = ({ children, className }) => {
  return (
    <div
      className={classNames(
        'absolute top-0 left-0 w-full h-12 dark:bg-gray-900 shadow-sm z-10 flex items-center px-2',
        className,
      )}
    >
      {children ? children : null}
    </div>
  );
};

export const MainContent: React.FC<Props> = ({ children, className }) => {
  return (
    <div className={classNames('pt-14 px-4 absolute top-0 left-0 w-full h-full overflow-auto', className)}>
      {children ? children : null}
    </div>
  );
};
30
ui/src/hooks/useDatasetList.tsx
Normal file
@@ -0,0 +1,30 @@
'use client';

import { useEffect, useState } from 'react';

export default function useDatasetList() {
  const [datasets, setDatasets] = useState<string[]>([]);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');

  const refreshDatasets = () => {
    setStatus('loading');
    fetch('/api/datasets/list')
      .then(res => res.json())
      .then(data => {
        console.log('Datasets:', data);
        // sort
        data.sort((a: string, b: string) => a.localeCompare(b));
        setDatasets(data);
        setStatus('success');
      })
      .catch(error => {
        console.error('Error fetching datasets:', error);
        setStatus('error');
      });
  };
  useEffect(() => {
    refreshDatasets();
  }, []);

  return { datasets, setDatasets, status, refreshDatasets };
}
52
ui/src/hooks/useFilesList.tsx
Normal file
@@ -0,0 +1,52 @@
'use client';

import { useEffect, useState, useRef } from 'react';

interface FileObject {
  path: string;
  size: number;
}

export default function useFilesList(jobID: string, reloadInterval: null | number = null) {
  const [files, setFiles] = useState<FileObject[]>([]);
  const didInitialLoadRef = useRef(false);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle');

  const refreshFiles = () => {
    let loadStatus: 'loading' | 'refreshing' = 'loading';
    if (didInitialLoadRef.current) {
      loadStatus = 'refreshing';
    }
    setStatus(loadStatus);
    fetch(`/api/jobs/${jobID}/files`)
      .then(res => res.json())
      .then(data => {
        console.log('Fetched files:', data);
        if (data.files) {
          setFiles(data.files);
        }
        setStatus('success');
        didInitialLoadRef.current = true;
      })
      .catch(error => {
        console.error('Error fetching files:', error);
        setStatus('error');
      });
  };

  useEffect(() => {
    refreshFiles();

    if (reloadInterval) {
      const interval = setInterval(() => {
        refreshFiles();
      }, reloadInterval);

      return () => {
        clearInterval(interval);
      };
    }
  }, [jobID]);

  return { files, setFiles, status, refreshFiles };
}
54
ui/src/hooks/useGPUInfo.tsx
Normal file
@@ -0,0 +1,54 @@
'use client';

import { GPUApiResponse, GpuInfo } from '@/types';
import { useEffect, useState } from 'react';

export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
  const [gpuList, setGpuList] = useState<GpuInfo[]>([]);
  const [isGPUInfoLoaded, setIsLoaded] = useState(false);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');

  const fetchGpuInfo = async () => {
    setStatus('loading');
    try {
      const response = await fetch('/api/gpu');

      if (!response.ok) {
        throw new Error(`HTTP error! Status: ${response.status}`);
      }

      const data: GPUApiResponse = await response.json();
      let gpus = data.gpus.sort((a, b) => a.index - b.index);
      if (gpuIds) {
        gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
      }

      setGpuList(gpus);
      setStatus('success');
    } catch (err) {
      console.error(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
      setStatus('error');
    } finally {
      setIsLoaded(true);
    }
  };

  useEffect(() => {
    // Fetch immediately on component mount
    fetchGpuInfo();

    // Set up interval if specified
    if (reloadInterval) {
      const interval = setInterval(() => {
        fetchGpuInfo();
      }, reloadInterval);

      // Cleanup interval on unmount
      return () => {
        clearInterval(interval);
      };
    }
  }, [gpuIds, reloadInterval]); // Added dependencies

  return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo };
}
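A minimal usage sketch for the hook above (assumed, not taken from the diff): passing a `reloadInterval` makes it poll `/api/gpu`, and a `gpuIds` filter limits the list to the GPUs a job is assigned to. The component name is hypothetical:

```tsx
import { useMemo } from 'react';
import useGPUInfo from '@/hooks/useGPUInfo';

// Illustrative widget: watch GPU 0 and refresh every 5 seconds.
// The id array is memoized so the hook's [gpuIds, reloadInterval] effect
// does not re-run on every render.
export function GpuTemperature() {
  const gpuIds = useMemo(() => [0], []);
  const { gpuList, isGPUInfoLoaded } = useGPUInfo(gpuIds, 5000);
  if (!isGPUInfoLoaded || gpuList.length === 0) return null;
  return (
    <span>
      {gpuList[0].name}: {gpuList[0].temperature}°C
    </span>
  );
}
```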
40
ui/src/hooks/useJob.tsx
Normal file
@@ -0,0 +1,40 @@
'use client';

import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';

export default function useJob(jobID: string, reloadInterval: null | number = null) {
  const [job, setJob] = useState<Job | null>(null);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');

  const refreshJob = () => {
    setStatus('loading');
    fetch(`/api/jobs?id=${jobID}`)
      .then(res => res.json())
      .then(data => {
        console.log('Job:', data);
        setJob(data);
        setStatus('success');
      })
      .catch(error => {
        console.error('Error fetching job:', error);
        setStatus('error');
      });
  };

  useEffect(() => {
    refreshJob();

    if (reloadInterval) {
      const interval = setInterval(() => {
        refreshJob();
      }, reloadInterval);

      return () => {
        clearInterval(interval);
      };
    }
  }, [jobID]);

  return { job, setJob, status, refreshJob };
}
37
ui/src/hooks/useJobsList.tsx
Normal file
@@ -0,0 +1,37 @@
'use client';

import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';

export default function useJobsList(onlyActive = false) {
  const [jobs, setJobs] = useState<Job[]>([]);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');

  const refreshJobs = () => {
    setStatus('loading');
    fetch('/api/jobs')
      .then(res => res.json())
      .then(data => {
        console.log('Jobs:', data);
        if (data.error) {
          console.log('Error fetching jobs:', data.error);
          setStatus('error');
        } else {
          if (onlyActive) {
            data.jobs = data.jobs.filter((job: Job) => job.status === 'running');
          }
          setJobs(data.jobs);
          setStatus('success');
        }
      })
      .catch(error => {
        console.error('Error fetching jobs:', error);
        setStatus('error');
      });
  };
  useEffect(() => {
    refreshJobs();
  }, []);

  return { jobs, setJobs, status, refreshJobs };
}
40
ui/src/hooks/useSampleImages.tsx
Normal file
@@ -0,0 +1,40 @@
'use client';

import { useEffect, useState } from 'react';

export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
  const [sampleImages, setSampleImages] = useState<string[]>([]);
  const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');

  const refreshSampleImages = () => {
    setStatus('loading');
    fetch(`/api/jobs/${jobID}/samples`)
      .then(res => res.json())
      .then(data => {
        if (data.samples) {
          setSampleImages(data.samples);
        }
        setStatus('success');
      })
      .catch(error => {
        console.error('Error fetching sample images:', error);
        setStatus('error');
      });
  };

  useEffect(() => {
    refreshSampleImages();

    if (reloadInterval) {
      const interval = setInterval(() => {
        refreshSampleImages();
      }, reloadInterval);

      return () => {
        clearInterval(interval);
      };
    }
  }, [jobID]);

  return { sampleImages, setSampleImages, status, refreshSampleImages };
}
28
ui/src/hooks/useSettings.tsx
Normal file
@@ -0,0 +1,28 @@
'use client';

import { useEffect, useState } from 'react';

export default function useSettings() {
  const [settings, setSettings] = useState({
    HF_TOKEN: '',
    TRAINING_FOLDER: '',
    DATASETS_FOLDER: '',
  });
  const [isSettingsLoaded, setIsLoaded] = useState(false);
  useEffect(() => {
    // Fetch current settings
    fetch('/api/settings')
      .then(res => res.json())
      .then(data => {
        setSettings({
          HF_TOKEN: data.HF_TOKEN || '',
          TRAINING_FOLDER: data.TRAINING_FOLDER || '',
          DATASETS_FOLDER: data.DATASETS_FOLDER || '',
        });
        setIsLoaded(true);
      })
      .catch(error => console.error('Error fetching settings:', error));
  }, []);

  return { settings, setSettings, isSettingsLoaded };
}
4
ui/src/paths.ts
Normal file
@@ -0,0 +1,4 @@
import path from 'path';
export const TOOLKIT_ROOT = path.resolve('@', '..', '..');
export const defaultTrainFolder = path.join(TOOLKIT_ROOT, 'output');
export const defaultDatasetsFolder = path.join(TOOLKIT_ROOT, 'datasets');
68
ui/src/server/settings.ts
Normal file
@@ -0,0 +1,68 @@
import { PrismaClient } from '@prisma/client';
import { defaultDatasetsFolder } from '@/paths';
import { defaultTrainFolder } from '@/paths';
import NodeCache from 'node-cache';

const myCache = new NodeCache();
const prisma = new PrismaClient();

export const flushCache = () => {
  myCache.flushAll();
};

export const getDatasetsRoot = async () => {
  const key = 'DATASETS_FOLDER';
  let datasetsPath = myCache.get(key) as string;
  if (datasetsPath) {
    return datasetsPath;
  }
  let row = await prisma.settings.findFirst({
    where: {
      key: 'DATASETS_FOLDER',
    },
  });
  datasetsPath = defaultDatasetsFolder;
  if (row?.value && row.value !== '') {
    datasetsPath = row.value;
  }
  myCache.set(key, datasetsPath);
  return datasetsPath as string;
};

export const getTrainingFolder = async () => {
  const key = 'TRAINING_FOLDER';
  let trainingRoot = myCache.get(key) as string;
  if (trainingRoot) {
    return trainingRoot;
  }
  let row = await prisma.settings.findFirst({
    where: {
      key: key,
    },
  });
  trainingRoot = defaultTrainFolder;
  if (row?.value && row.value !== '') {
    trainingRoot = row.value;
  }
  myCache.set(key, trainingRoot);
  return trainingRoot as string;
};

export const getHFToken = async () => {
  const key = 'HF_TOKEN';
  let token = myCache.get(key) as string;
  if (token) {
    return token;
  }
  let row = await prisma.settings.findFirst({
    where: {
      key: key,
    },
  });
  token = '';
  if (row?.value && row.value !== '') {
    token = row.value;
  }
  myCache.set(key, token);
  return token;
};
160
ui/src/types.ts
Normal file
@@ -0,0 +1,160 @@
/**
 * GPU API response
 */

export interface GpuUtilization {
  gpu: number;
  memory: number;
}

export interface GpuMemory {
  total: number;
  free: number;
  used: number;
}

export interface GpuPower {
  draw: number;
  limit: number;
}

export interface GpuClocks {
  graphics: number;
  memory: number;
}

export interface GpuFan {
  speed: number;
}

export interface GpuInfo {
  index: number;
  name: string;
  driverVersion: string;
  temperature: number;
  utilization: GpuUtilization;
  memory: GpuMemory;
  power: GpuPower;
  clocks: GpuClocks;
  fan: GpuFan;
}

export interface GPUApiResponse {
  hasNvidiaSmi: boolean;
  gpus: GpuInfo[];
  error?: string;
}

/**
 * Training configuration
 */

export interface NetworkConfig {
  type: 'lora';
  linear: number;
  linear_alpha: number;
}

export interface SaveConfig {
  dtype: string;
  save_every: number;
  max_step_saves_to_keep: number;
  save_format: string;
  push_to_hub: boolean;
}

export interface DatasetConfig {
  folder_path: string;
  mask_path: string | null;
  mask_min_value: number;
  default_caption: string;
  caption_ext: string;
  caption_dropout_rate: number;
  shuffle_tokens?: boolean;
  is_reg: boolean;
  network_weight: number;
  cache_latents_to_disk?: boolean;
  resolution: number[];
}

export interface EMAConfig {
  use_ema: boolean;
  ema_decay: number;
}

export interface TrainConfig {
  batch_size: number;
  bypass_guidance_embedding?: boolean;
  steps: number;
  gradient_accumulation: number;
  train_unet: boolean;
  train_text_encoder: boolean;
  gradient_checkpointing: boolean;
  noise_scheduler: string;
  timestep_type: string;
  content_or_style: string;
  optimizer: string;
  lr: number;
  ema_config?: EMAConfig;
  dtype: string;
  optimizer_params: {
    weight_decay: number;
  };
}

export interface QuantizeKwargsConfig {
  exclude: string[];
}

export interface ModelConfig {
  name_or_path: string;
  is_flux?: boolean;
  is_lumina2?: boolean;
  quantize: boolean;
  quantize_te: boolean;
  quantize_kwargs?: QuantizeKwargsConfig;
}

export interface SampleConfig {
  sampler: string;
  sample_every: number;
  width: number;
  height: number;
  prompts: string[];
  neg: string;
  seed: number;
  walk_seed: boolean;
  guidance_scale: number;
  sample_steps: number;
}

export interface ProcessConfig {
  type: 'ui_trainer';
  sqlite_db_path?: string;
  training_folder: string;
  performance_log_every: number;
  trigger_word: string | null;
  device: string;
  network?: NetworkConfig;
  save: SaveConfig;
  datasets: DatasetConfig[];
  train: TrainConfig;
  model: ModelConfig;
  sample: SampleConfig;
}

export interface ConfigObject {
  name: string;
  process: ProcessConfig[];
}

export interface MetaConfig {
  name: string;
  version: string;
}

export interface JobConfig {
  job: string;
  config: ConfigObject;
  meta: MetaConfig;
}
4
ui/src/utils/basic.ts
Normal file
@@ -0,0 +1,4 @@
export const objectCopy = <T>(obj: T): T => {
  return JSON.parse(JSON.stringify(obj)) as T;
};
88
ui/src/utils/hooks.tsx
Normal file
@@ -0,0 +1,88 @@
import React from 'react';

/**
 * Updates a deeply nested value in an object using a string path
 * @param obj The object to update
 * @param value The new value to set
 * @param path String path to the property (e.g. 'config.process[0].model.name_or_path')
 * @returns A new object with the updated value
 */
export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
  // Create a copy of the original object to maintain immutability
  const result = { ...obj };

  // If no path is provided, treat it as the root path
  if (!path) {
    path = '';
  }

  // Split the path into segments
  const pathArray = path.split('.').flatMap(segment => {
    // Handle array notation like 'process[0]'
    const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/);
    if (arrayMatch) {
      const propName = arrayMatch[1];
      const indices = segment
        .substring(propName.length)
        .match(/\[(\d+)\]/g)
        ?.map(idx => parseInt(idx.substring(1, idx.length - 1)));

      // Return property name followed by array indices
      return [propName, ...(indices || [])];
    }
    return segment;
  });

  // Navigate to the target location
  let current: any = result;
  for (let i = 0; i < pathArray.length - 1; i++) {
    const key = pathArray[i];

    // If current key is a number, treat it as an array index
    if (typeof key === 'number') {
      if (!Array.isArray(current)) {
        throw new Error(`Cannot access index ${key} of non-array`);
      }
      // Create a copy of the array to maintain immutability
      current = [...current];
    } else {
      // For object properties, create a new object if it doesn't exist
      if (current[key] === undefined) {
        // Check if the next key is a number, if so create an array, otherwise an object
        const nextKey = pathArray[i + 1];
        current[key] = typeof nextKey === 'number' ? [] : {};
      } else {
        // Create a shallow copy to maintain immutability
        current[key] = Array.isArray(current[key]) ? [...current[key]] : { ...current[key] };
      }
    }

    // Move to the next level
    current = current[key];
  }

  // Set the value at the final path segment
  const finalKey = pathArray[pathArray.length - 1];
  current[finalKey] = value;

  return result;
}

/**
 * Custom hook for managing a complex state object with string path updates
 * @param initialState The initial state object
 * @returns [state, setValue] tuple
 */
export function useNestedState<T>(initialState: T): [T, (value: any, path?: string) => void] {
  const [state, setState] = React.useState<T>(initialState);

  const setValue = React.useCallback((value: any, path?: string) => {
    if (path === undefined) {
      setState(value);
      return;
    }
    setState(prevState => setNestedValue(prevState, value, path));
  }, []);

  return [state, setValue];
}
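A brief sketch of the path syntax these helpers accept (the object and values here are illustrative, not from the diff): `setNestedValue` copies nodes along the path and writes the leaf, and `useNestedState` exposes the same path-based update as component state:

```tsx
import { useNestedState, setNestedValue } from '@/utils/hooks';

// Plain update on an illustrative object shaped like a JobConfig fragment.
const cfg = { config: { process: [{ train: { steps: 2000 } }] } };
const next = setNestedValue(cfg, 3000, 'config.process[0].train.steps');
console.log(next.config.process[0].train.steps); // 3000

// Inside a component, the same path syntax drives state updates.
function ExampleStepsEditor() {
  const [state, setValue] = useNestedState(cfg);
  return (
    <button onClick={() => setValue(4000, 'config.process[0].train.steps')}>
      {state.config.process[0].train.steps}
    </button>
  );
}
```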
75
ui/src/utils/jobs.ts
Normal file
@@ -0,0 +1,75 @@
import { JobConfig } from '@/types';
import { Job } from '@prisma/client';

export const startJob = (jobID: string) => {
  return new Promise<void>((resolve, reject) => {
    fetch(`/api/jobs/${jobID}/start`)
      .then(res => res.json())
      .then(data => {
        console.log('Job started:', data);
        resolve();
      })
      .catch(error => {
        console.error('Error starting job:', error);
        reject(error);
      });
  });
};

export const stopJob = (jobID: string) => {
  return new Promise<void>((resolve, reject) => {
    fetch(`/api/jobs/${jobID}/stop`)
      .then(res => res.json())
      .then(data => {
        console.log('Job stopped:', data);
        resolve();
      })
      .catch(error => {
        console.error('Error stopping job:', error);
        reject(error);
      });
  });
};

export const deleteJob = (jobID: string) => {
  return new Promise<void>((resolve, reject) => {
    fetch(`/api/jobs/${jobID}/delete`)
      .then(res => res.json())
      .then(data => {
        console.log('Job deleted:', data);
        resolve();
      })
      .catch(error => {
        console.error('Error deleting job:', error);
        reject(error);
      });
  });
};

export const getJobConfig = (job: Job) => {
  return JSON.parse(job.job_config) as JobConfig;
};

export const getAvaliableJobActions = (job: Job) => {
  const jobConfig = getJobConfig(job);
  const isStopping = job.stop && job.status === 'running';
  const canDelete = ['completed', 'stopped', 'error'].includes(job.status) && !isStopping;
  const canEdit = ['completed', 'stopped', 'error'].includes(job.status) && !isStopping;
  const canStop = job.status === 'running' && !isStopping;
  let canStart = ['stopped', 'error'].includes(job.status) && !isStopping;
  // can resume if more steps were added
  if (job.status === 'completed' && jobConfig.config.process[0].train.steps > job.step && !isStopping) {
    canStart = true;
  }
  return { canDelete, canEdit, canStop, canStart };
};

export const getNumberOfSamples = (job: Job) => {
  const jobConfig = getJobConfig(job);
  return jobConfig.config.process[0].sample?.prompts?.length || 0;
};

export const getTotalSteps = (job: Job) => {
  const jobConfig = getJobConfig(job);
  return jobConfig.config.process[0].train.steps;
};
31
ui/tailwind.config.ts
Normal file
@@ -0,0 +1,31 @@
import type { Config } from "tailwindcss";

const config: Config = {
  content: [
    "./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
    "./src/components/**/*.{js,ts,jsx,tsx,mdx}",
    "./src/app/**/*.{js,ts,jsx,tsx,mdx}",
  ],
  darkMode: "class",
  theme: {
    extend: {
      colors: {
        gray: {
          950: "#0a0a0a",
          900: "#171717",
          800: "#262626",
          700: "#404040",
          600: "#525252",
          500: "#737373",
          400: "#a3a3a3",
          300: "#d4d4d4",
          200: "#e5e5e5",
          100: "#f5f5f5",
        },
      },
    },
  },
  plugins: [],
};

export default config;
27
ui/tsconfig.json
Normal file
@@ -0,0 +1,27 @@
{
  "compilerOptions": {
    "target": "ES2017",
    "lib": ["dom", "dom.iterable", "esnext"],
    "allowJs": true,
    "skipLibCheck": true,
    "strict": true,
    "noEmit": true,
    "esModuleInterop": true,
    "module": "esnext",
    "moduleResolution": "bundler",
    "resolveJsonModule": true,
    "isolatedModules": true,
    "jsx": "preserve",
    "incremental": true,
    "plugins": [
      {
        "name": "next"
      }
    ],
    "paths": {
      "@/*": ["./src/*"]
    }
  },
  "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
  "exclude": ["node_modules"]
}