mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Initial commit
This commit is contained in:
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import io
|
||||
import struct
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
@@ -425,43 +426,63 @@ def main(argv=None):
|
||||
|
||||
|
||||
is_window_shown = False
|
||||
display_lock = threading.Lock()
|
||||
current_img = None
|
||||
update_event = threading.Event()
|
||||
|
||||
def update_image(img, name):
|
||||
global current_img
|
||||
with display_lock:
|
||||
current_img = (img, name)
|
||||
update_event.set()
|
||||
|
||||
def display_image_in_thread():
|
||||
global is_window_shown
|
||||
|
||||
def display_img():
|
||||
global current_img
|
||||
while True:
|
||||
update_event.wait()
|
||||
with display_lock:
|
||||
if current_img:
|
||||
img, name = current_img
|
||||
cv2.imshow(name, img)
|
||||
current_img = None
|
||||
update_event.clear()
|
||||
if cv2.waitKey(1) & 0xFF == 27: # Esc key to stop
|
||||
cv2.destroyAllWindows()
|
||||
print('\nESC pressed, stopping')
|
||||
break
|
||||
|
||||
if not is_window_shown:
|
||||
is_window_shown = True
|
||||
threading.Thread(target=display_img, daemon=True).start()
|
||||
|
||||
|
||||
def show_img(img, name='AI Toolkit'):
|
||||
global is_window_shown
|
||||
|
||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||
cv2.imshow(name, img[:, :, ::-1])
|
||||
k = cv2.waitKey(1) & 0xFF
|
||||
if k == 27: # Esc key to stop
|
||||
print('\nESC pressed, stopping')
|
||||
raise KeyboardInterrupt
|
||||
update_image(img[:, :, ::-1], name)
|
||||
if not is_window_shown:
|
||||
is_window_shown = True
|
||||
|
||||
display_image_in_thread()
|
||||
|
||||
|
||||
def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
|
||||
# if rank is 4
|
||||
if len(imgs.shape) == 4:
|
||||
img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
|
||||
else:
|
||||
img_list = [imgs]
|
||||
# put images side by side
|
||||
|
||||
img = torch.cat(img_list, dim=3)
|
||||
# img is -1 to 1, convert to 0 to 255
|
||||
img = img / 2 + 0.5
|
||||
img_numpy = img.to(torch.float32).detach().cpu().numpy()
|
||||
img_numpy = np.clip(img_numpy, 0, 1) * 255
|
||||
# convert to numpy Move channel to last
|
||||
img_numpy = img_numpy.transpose(0, 2, 3, 1)
|
||||
# convert to uint8
|
||||
img_numpy = img_numpy.astype(np.uint8)
|
||||
|
||||
show_img(img_numpy[0], name=name)
|
||||
|
||||
|
||||
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
||||
# decode latents
|
||||
if vae.device == 'cpu':
|
||||
vae.to(latents.device)
|
||||
latents = latents / vae.config['scaling_factor']
|
||||
@@ -469,7 +490,6 @@ def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit
|
||||
show_tensors(imgs, name=name)
|
||||
|
||||
|
||||
|
||||
def on_exit():
|
||||
if is_window_shown:
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
Reference in New Issue
Block a user