Added caching of image sizes so we don't recompute them on every run.

Jaret Burkett
2024-07-15 19:07:41 -06:00
parent e4558dff4b
commit 58dffd43a8
7 changed files with 90 additions and 34 deletions


@@ -714,6 +714,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
         sigma = sigma.unsqueeze(-1)
         return sigma
+    def get_noise(self, latents, batch_size, dtype=torch.float32):
+        # get noise
+        noise = self.sd.get_latent_noise(
+            height=latents.shape[2],
+            width=latents.shape[3],
+            batch_size=batch_size,
+            noise_offset=self.train_config.noise_offset,
+        ).to(self.device_torch, dtype=dtype)
+        return noise
     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
         with torch.no_grad():
             with self.timer('prepare_prompt'):
@@ -927,12 +938,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             timesteps = torch.stack(timesteps, dim=0)
             # get noise
-            noise = self.sd.get_latent_noise(
-                height=latents.shape[2],
-                width=latents.shape[3],
-                batch_size=batch_size,
-                noise_offset=self.train_config.noise_offset,
-            ).to(self.device_torch, dtype=dtype)
+            noise = self.get_noise(latents, batch_size, dtype=dtype)
             # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
             # this will negate any noise offsets
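The comment above describes dynamic noise offset as shifting the noise so its channel-wise mean matches the channel-wise mean of the latents, which cancels out any fixed noise_offset applied earlier. The code that does this is not shown in the hunk; the following is only a minimal sketch of that idea, with an illustrative function name rather than the trainer's actual implementation.

import torch

def apply_dynamic_noise_offset(noise: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
    # noise and latents are (batch, channels, height, width)
    # shift each channel of the noise so its mean matches the latents' channel mean;
    # any constant offset baked into the noise earlier is effectively negated
    latent_mean = latents.mean(dim=(2, 3), keepdim=True)
    noise_mean = noise.mean(dim=(2, 3), keepdim=True)
    return noise + (latent_mean - noise_mean)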


@@ -11,10 +11,10 @@ import json
 # te_path = "google/flan-t5-xl"
 # te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
-te_path = "google/flan-t5-large"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
+model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch"
+te_path = "google/flan-t5-xl"
+te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1"
 print("Loading te adapter")
@@ -28,13 +28,13 @@ is_pixart = "pixart" in model_path.lower()
 pipeline_class = StableDiffusionPipeline
-transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
+# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
 if is_pixart:
     pipeline_class = PixArtSigmaPipeline
 if is_diffusers:
-    sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
+    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
 else:
     sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
@@ -50,7 +50,7 @@ if is_pixart:
     unet = sd.transformer
     unet_sd = sd.transformer.state_dict()
 else:
-    unet = sd.transformer
+    unet = sd.unet
     unet_sd = sd.unet.state_dict()


@@ -187,7 +187,7 @@ for epoch in range(args.epochs):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
-        img_batch = color_block_imgs(img_batch, neg1_1=True)
+        # img_batch = color_block_imgs(img_batch, neg1_1=True)
         chunks = torch.chunk(img_batch, batch_size, dim=0)
         # put them so they are side by side
@@ -208,9 +208,9 @@ for epoch in range(args.epochs):
         # convert to image
         img = transforms.ToPILImage()(big_img)
-        show_img(img)
+        # show_img(img)
-        time.sleep(1.0)
+        # time.sleep(1.0)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)


@@ -39,7 +39,8 @@ from transformers import (
     AutoImageProcessor,
     ConvNextModel,
     ConvNextForImageClassification,
-    ConvNextImageProcessor
+    ConvNextImageProcessor,
+    UMT5EncoderModel, LlamaTokenizerFast
 )
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -165,6 +166,23 @@ class CustomAdapter(torch.nn.Module):
             # self.te.to = lambda *args, **kwargs: None
             self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
+        elif self.config.text_encoder_arch == 'pile-t5':
+            te_kwargs = {}
+            # te_kwargs['load_in_4bit'] = True
+            # te_kwargs['load_in_8bit'] = True
+            te_kwargs['device_map'] = "auto"
+            te_is_quantized = True
+            self.te = UMT5EncoderModel.from_pretrained(
+                self.config.text_encoder_path,
+                torch_dtype=torch_dtype,
+                **te_kwargs
+            )
+            # self.te.to = lambda *args, **kwargs: None
+            self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
         elif self.config.text_encoder_arch == 'clip':
             self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
                                                                                       dtype=torch_dtype)


@@ -433,7 +433,15 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         ])
         # this might take a while
         print(f"Dataset: {self.dataset_path}")
+        print(f" - Preprocessing image dimensions")
+        dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
+        if os.path.exists(dataset_size_file):
+            with open(dataset_size_file, 'r') as f:
+                self.size_database = json.load(f)
+        else:
+            self.size_database = {}
         bad_count = 0
         for file in tqdm(file_list):
             try:
@@ -442,6 +450,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                     path=file,
                     dataset_config=dataset_config,
                     dataloader_transforms=self.transform,
+                    size_database=self.size_database,
                 )
                 self.file_list.append(file_item)
             except Exception as e:
@@ -450,6 +459,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                 print(e)
                 bad_count += 1
+        # save the size database
+        with open(dataset_size_file, 'w') as f:
+            json.dump(self.size_database, f)
         print(f" - Found {len(self.file_list)} images")
         # print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"


@@ -1,3 +1,4 @@
+import os
 import weakref
 from _weakref import ReferenceType
 from typing import TYPE_CHECKING, List, Union
@@ -40,16 +41,22 @@ class FileItemDTO(
     ArgBreakMixin,
 ):
     def __init__(self, *args, **kwargs):
-        self.path = kwargs.get('path', None)
+        self.path = kwargs.get('path', '')
         self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
-        # process width and height
-        try:
-            w, h = image_utils.get_image_size(self.path)
-        except image_utils.UnknownImageFormat:
-            print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
-                       f'This process is faster for png, jpeg')
-            img = exif_transpose(Image.open(self.path))
-            h, w = img.size
+        size_database = kwargs.get('size_database', {})
+        filename = os.path.basename(self.path)
+        if filename in size_database:
+            w, h = size_database[filename]
+        else:
+            # process width and height
+            try:
+                w, h = image_utils.get_image_size(self.path)
+            except image_utils.UnknownImageFormat:
+                print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
+                           f'This process is faster for png, jpeg')
+                img = exif_transpose(Image.open(self.path))
+                h, w = img.size
+            size_database[filename] = (w, h)
         self.width: int = w
         self.height: int = h
         self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
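Taken together, the two hunks above are the commit's size cache: the dataset loads .aitk_size.json once, each FileItemDTO looks its file up by basename and only falls back to reading the image on a miss, and the dataset writes the database back after the scan. A condensed, self-contained sketch of the same pattern follows; the function names here are illustrative, not the toolkit's API.

import json
import os
from PIL import Image

def load_size_database(dataset_path: str) -> dict:
    # cached {filename: [width, height]} entries from a previous run, if any
    size_file = os.path.join(dataset_path, '.aitk_size.json')
    if os.path.exists(size_file):
        with open(size_file, 'r') as f:
            return json.load(f)
    return {}

def get_image_size(path: str, size_database: dict):
    # look up by basename first; only open the image on a cache miss
    filename = os.path.basename(path)
    if filename in size_database:
        return tuple(size_database[filename])
    with Image.open(path) as img:
        w, h = img.size
    size_database[filename] = (w, h)
    return w, h

def save_size_database(dataset_path: str, size_database: dict) -> None:
    # persist the sizes so the next dataset scan can skip the per-image reads
    with open(os.path.join(dataset_path, '.aitk_size.json'), 'w') as f:
        json.dump(size_database, f)

Note that the diff keys the database by basename, so the cache implicitly assumes filenames are unique within a dataset folder.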


@@ -243,7 +243,7 @@ class TEAdapter(torch.nn.Module):
         self.embeds_store = []
         is_pixart = sd.is_pixart
-        if self.adapter_ref().config.text_encoder_arch == "t5":
+        if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
             self.token_size = self.te_ref().config.d_model
         else:
             self.token_size = self.te_ref().config.target_hidden_size
@@ -388,13 +388,25 @@ class TEAdapter(torch.nn.Module):
         # ).input_ids.to(te.device)
         # outputs = te(input_ids=input_ids)
         # outputs = outputs.last_hidden_state
-        embeds, attention_mask = train_tools.encode_prompts_pixart(
-            tokenizer,
-            te,
-            text,
-            truncate=True,
-            max_length=self.adapter_ref().config.num_tokens,
-        )
+        if self.adapter_ref().config.text_encoder_arch == "pile-t5":
+            # just use aura pile
+            embeds, attention_mask = train_tools.encode_prompts_auraflow(
+                tokenizer,
+                te,
+                text,
+                truncate=True,
+                max_length=self.adapter_ref().config.num_tokens,
+            )
+        else:
+            embeds, attention_mask = train_tools.encode_prompts_pixart(
+                tokenizer,
+                te,
+                text,
+                truncate=True,
+                max_length=self.adapter_ref().config.num_tokens,
+            )
         attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
         if self.text_projection is not None:
             # pool the output of embeds ignoring 0 in the attention mask
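The pile-t5 branch calls train_tools.encode_prompts_auraflow, which this diff does not show. As context, the sketch below is only an assumption of what such a call generally looks like with the UMT5EncoderModel and LlamaTokenizerFast pair loaded in the adapter hunk above; the function name and default max_length are hypothetical, not the toolkit's implementation.

import torch
from transformers import LlamaTokenizerFast, UMT5EncoderModel

def encode_prompts_umt5(tokenizer: LlamaTokenizerFast, te: UMT5EncoderModel,
                        prompts, max_length: int = 77):
    # hypothetical stand-in for train_tools.encode_prompts_auraflow:
    # tokenize with padding/truncation (this is why a pad token is required,
    # hence the '[PAD]' fallback added in CustomAdapter), then run the encoder
    tokens = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(te.device)
    with torch.no_grad():
        embeds = te(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask,
        ).last_hidden_state
    # return the token embeddings plus the mask so padding can be ignored downstream
    return embeds, tokens.attention_mask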