mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-01 03:31:35 +00:00)

Bugfixes and cleanup
@@ -807,10 +807,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with self.timer('prepare_latents'):
             dtype = get_torch_dtype(self.train_config.dtype)
             imgs = None
+            is_reg = any(batch.get_is_reg_list())
             if batch.tensor is not None:
                 imgs = batch.tensor
                 imgs = imgs.to(self.device_torch, dtype=dtype)
-                if self.train_config.img_multiplier is not None:
+                # dont adjust for regs.
+                if self.train_config.img_multiplier is not None and not is_reg:
                     # do it ad contrast
                     imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
             if batch.latents is not None:
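Note: the fix above is the `not is_reg` guard, so regularization images in a batch are no longer contrast-adjusted. The toolkit's reduce_contrast helper itself is not shown in this commit; purely as an illustration, a contrast reduction of this shape is often a linear interpolation toward the per-image mean. A minimal sketch — the function body below is an assumption, not the toolkit's actual implementation:

import torch

def reduce_contrast_sketch(imgs: torch.Tensor, multiplier: float) -> torch.Tensor:
    # Assumed behavior: pull BCHW pixel values toward each image's mean,
    # scaling deviations by `multiplier` (1.0 = unchanged).
    mean = imgs.mean(dim=(1, 2, 3), keepdim=True)
    return mean + (imgs - mean) * multiplier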
@@ -1495,8 +1497,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             try:
                 print(f"Loading optimizer state from {optimizer_state_file_path}")
-                optimizer_state_dict = torch.load(optimizer_state_file_path)
+                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
                 optimizer.load_state_dict(optimizer_state_dict)
+                del optimizer_state_dict
+                flush()
             except Exception as e:
                 print(f"Failed to load optimizer state from {optimizer_state_file_path}")
                 print(e)
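Note: `torch.load` defaults to full pickle deserialization, which can execute arbitrary code from an untrusted checkpoint; `weights_only=True` (available since PyTorch 1.13) restricts unpickling to tensors and other safe types. A self-contained sketch of the same load-then-free pattern, with a stand-in for the toolkit's flush() helper (the gc/empty_cache lines are an assumption about what flush() does, not its actual code):

import gc
import torch

def load_optimizer_state(optimizer, path: str) -> None:
    # Safe unpickling: tensors and primitives only.
    state = torch.load(path, weights_only=True)
    optimizer.load_state_dict(state)
    # Drop the extra copy of the state dict promptly.
    del state
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # rough stand-in for the toolkit's flush()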
@@ -8,6 +8,7 @@ import sys
 import os
 import cv2
 import random
+from transformers import CLIPImageProcessor
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from toolkit.paths import SD_SCRIPTS_ROOT
@@ -37,12 +38,25 @@ resolution = 512
 bucket_tolerance = 64
 batch_size = 1
 
+clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+class FakeAdapter:
+    def __init__(self):
+        self.clip_image_processor = clip_processor
+
+
+## make fake sd
+class FakeSD:
+    def __init__(self):
+        self.adapter = FakeAdapter()
+
+
 
-##
 
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
-    control_path=dataset_folder,
+    clip_image_path=dataset_folder,
+    square_crop=True,
     resolution=resolution,
     # caption_ext='json',
     default_caption='default',
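Note: the FakeSD/FakeAdapter stubs let this test script exercise the new clip_image_path handling without loading a real Stable Diffusion model; the dataloader only needs something exposing `sd.adapter.clip_image_processor`. For reference, this is roughly how a CLIPImageProcessor turns a PIL image into adapter-ready pixel values (a generic transformers usage sketch, not code from this repo; the stand-in image is arbitrary):

from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
image = Image.new("RGB", (640, 480), color=(128, 64, 32))  # stand-in image

# Resizes/crops to the CLIP input size and normalizes with CLIP's mean/std.
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224])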
@@ -61,123 +75,7 @@ dataset_config = DatasetConfig(
     # ]
 )
 
-dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
+dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
 
-def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
-    if random.random() < p:
-        kernel_size = random.randint(min_kernel_size, max_kernel_size)
-        # make sure it is odd
-        if kernel_size % 2 == 0:
-            kernel_size += 1
-        img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
-    return img
-
-def quantize(image, palette):
-    """
-    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
-    Only works for one image i.e. CHW. Does NOT work for batches.
-    ref https://discuss.pytorch.org/t/color-quantization/104528/4
-    """
-
-    orig_dtype = image.dtype
-
-    C, H, W = image.shape
-    n_colors = palette.shape[0]
-
-    # Easier to work with list of colors
-    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]
-
-    # Repeat image so that there are n_color number of columns of the same image
-    flat_img_per_color = flat_img.unsqueeze(1).expand(-1, n_colors, -1)  # [H*W, C] -> [H*W, n_colors, C]
-
-    # Get euclidean distance between each pixel in each column and the column's respective color
-    # i.e. column 1 lists distance of each pixel to color #1 in palette, column 2 to color #2 etc.
-    squared_distance = (flat_img_per_color - palette.unsqueeze(0)) ** 2
-    euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8)  # [H*W, n_colors, C] -> [H*W, n_colors]
-
-    # Get the shortest distance (one value per row (H*W) is selected)
-    min_distances, min_indices = torch.min(euclidean_distance, dim=-1)  # [H*W, n_colors] -> [H*W]
-
-    # Create a mask for the closest colors
-    one_hot_mask = torch.nn.functional.one_hot(min_indices, num_classes=n_colors).float()  # [H*W, n_colors]
-
-    # Multiply the mask with the palette colors to get the quantized image
-    quantized = torch.matmul(one_hot_mask, palette)  # [H*W, n_colors] @ [n_colors, C] -> [H*W, C]
-
-    # Reshape it back to the original input format.
-    quantized_img = quantized.T.view(C, H, W)  # [H*W, C] -> [C, H, W]
-
-    return quantized_img.to(orig_dtype)
-
-
-
-def color_block_imgs(img, neg1_1=False):
-    # expects values 0 - 1
-    orig_dtype = img.dtype
-    if neg1_1:
-        img = img * 0.5 + 0.5
-
-    img = img * 255
-    img = img.clamp(0, 255)
-    img = img.to(torch.uint8)
-
-    img_chunks = torch.chunk(img, img.shape[0], dim=0)
-
-    posterized_chunks = []
-
-    for chunk in img_chunks:
-        img_size = (chunk.shape[2] + chunk.shape[3]) // 2
-        # min kernel size of 1% of image, max 10%
-        min_kernel_size = int(img_size * 0.01)
-        max_kernel_size = int(img_size * 0.1)
-
-        # blur first
-        chunk = random_blur(chunk, min_kernel_size=min_kernel_size, max_kernel_size=max_kernel_size, p=0.8)
-        num_colors = random.randint(1, 16)
-
-        resize_to = 16
-        # chunk = torchvision.transforms.functional.posterize(chunk, num_bits_to_use)
-
-        # mean_color = [int(x.item()) for x in torch.mean(chunk.float(), dim=(0, 2, 3))]
-
-        # shrink the image down to num_colors x num_colors
-        shrunk = torchvision.transforms.functional.resize(chunk, [resize_to, resize_to])
-
-        mean_color = [int(x.item()) for x in torch.mean(shrunk.float(), dim=(0, 2, 3))]
-
-        colors = shrunk.view(3, -1).T
-        # remove duplicates
-        colors = torch.unique(colors, dim=0)
-        colors = colors.numpy()
-        colors = colors.tolist()
-
-        use_colors = [random.choice(colors) for _ in range(num_colors)]
-
-        pallette = torch.tensor([
-            [0, 0, 0],
-            mean_color,
-            [255, 255, 255],
-        ] + use_colors, dtype=torch.float32)
-        chunk = quantize(chunk.squeeze(0), pallette).unsqueeze(0)
-
-        # chunk = torchvision.transforms.functional.equalize(chunk)
-        # color jitter
-        if random.random() < 0.5:
-            chunk = torchvision.transforms.functional.adjust_contrast(chunk, random.uniform(1.0, 1.5))
-        if random.random() < 0.5:
-            chunk = torchvision.transforms.functional.adjust_saturation(chunk, random.uniform(1.0, 2.0))
-        # if random.random() < 0.5:
-        #     chunk = torchvision.transforms.functional.adjust_brightness(chunk, random.uniform(0.5, 1.5))
-        chunk = random_blur(chunk, p=0.6)
-        posterized_chunks.append(chunk)
-
-    img = torch.cat(posterized_chunks, dim=0)
-    img = img.to(orig_dtype)
-    img = img / 255
-
-    if neg1_1:
-        img = img * 2 - 1
-    return img
-
 
 # run through an epoch ang check sizes
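Note: the block removed above (random_blur / quantize / color_block_imgs) was an inline posterization experiment; the nearest-palette-color quantizer is the interesting part and its full source survives in the diff text. A usage sketch of that removed helper exactly as it was written (CHW float input, one image at a time; the palette values here are arbitrary examples):

import torch

# palette of 4 RGB colors, shape [n_colors, C]
palette = torch.tensor([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 1.0],
])
image = torch.rand(3, 64, 64)  # CHW; quantize() does not handle batches

# Each pixel is snapped to its nearest palette color by euclidean distance.
quantized = quantize(image, palette)
assert quantized.shape == image.shape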
@@ -186,6 +84,7 @@ for epoch in range(args.epochs):
     for batch in tqdm(dataloader):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
+        batch_size, channels, height, width = img_batch.shape
 
         # img_batch = color_block_imgs(img_batch, neg1_1=True)
 
@@ -194,10 +93,14 @@ for epoch in range(args.epochs):
         big_img = torch.cat(chunks, dim=3)
         big_img = big_img.squeeze(0)
 
-        control_chunks = torch.chunk(batch.control_tensor, batch_size, dim=0)
+        control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
         big_control_img = torch.cat(control_chunks, dim=3)
         big_control_img = big_control_img.squeeze(0) * 2 - 1
 
+
+        # resize control image
+        big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
+
         big_img = torch.cat([big_img, big_control_img], dim=2)
 
         min_val = big_img.min()
@@ -208,9 +111,9 @@ for epoch in range(args.epochs):
         # convert to image
         img = transforms.ToPILImage()(big_img)
 
-        # show_img(img)
+        show_img(img)
 
-        # time.sleep(1.0)
+        time.sleep(1.0)
         # if not last epoch
         if epoch < args.epochs - 1:
             trigger_dataloader_setup_epoch(dataloader)
@@ -154,6 +154,7 @@ class AdapterConfig:
         if num_tokens is None and self.type.startswith('ip'):
             if self.type == 'ip+':
                 num_tokens = 16
+                num_tokens = 16
             elif self.type == 'ip':
                 num_tokens = 4
 
@@ -760,11 +760,15 @@ class ClipImageFileItemDTOMixin:
             # do a flip
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
-        # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
         if img.width != img.height:
-            # resize to the smallest dimension
             min_size = min(img.width, img.height)
-            img = img.resize((min_size, min_size), Image.BICUBIC)
+            if self.dataset_config.square_crop:
+                # center crop to a square
+                img = transforms.CenterCrop(min_size)(img)
+            else:
+                # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
+                # resize to the smallest dimension
+                img = img.resize((min_size, min_size), Image.BICUBIC)
 
         if self.has_clip_augmentations:
             self.clip_image_tensor = self.augment_clip_image(img, transform=None)
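Note: the new square_crop flag chooses between two ways of squaring the clip image: center-cropping keeps the aspect ratio but trims edge content, while the old squish-resize keeps all content but distorts it. A self-contained sketch of the two paths as implemented above (the square_image wrapper name is hypothetical, for illustration only):

from PIL import Image
from torchvision import transforms

def square_image(img: Image.Image, square_crop: bool) -> Image.Image:
    if img.width == img.height:
        return img
    min_size = min(img.width, img.height)
    if square_crop:
        # center crop to a square: trims the longer dimension
        return transforms.CenterCrop(min_size)(img)
    # resize/squish so nothing is cropped out, at the cost of distortion
    return img.resize((min_size, min_size), Image.BICUBIC)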