Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-05 21:19:56 +00:00
Added caching to image sizes so we don't do it every time.
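In outline, the change caches each image's dimensions in a .aitk_size.json file stored inside the dataset folder, keyed by filename, so later runs can skip re-opening every image just to read its size. Below is a minimal standalone sketch of that pattern (helper names here are illustrative; the sketch uses plain PIL and omits the fast header-read fallback and the class structure that the actual diff wires into AiToolkitDataset and FileItemDTO):

import json
import os

from PIL import Image


def load_size_cache(dataset_path):
    # cached sizes live in a JSON file inside the dataset folder
    cache_file = os.path.join(dataset_path, '.aitk_size.json')
    if os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            return json.load(f)
    return {}


def get_cached_image_size(path, size_database):
    # look up by filename first; only open the image on a cache miss
    filename = os.path.basename(path)
    if filename in size_database:
        w, h = size_database[filename]
    else:
        with Image.open(path) as img:
            w, h = img.size
        size_database[filename] = (w, h)
    return w, h


def save_size_cache(dataset_path, size_database):
    # write the (possibly updated) cache back out after the dataset scan
    with open(os.path.join(dataset_path, '.aitk_size.json'), 'w') as f:
        json.dump(size_database, f)

The diff below applies the same idea: the dataset loads and saves the cache around its file scan, and each FileItemDTO consults the cache before falling back to reading the image.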
@@ -714,6 +714,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
         sigma = sigma.unsqueeze(-1)
         return sigma
 
+    def get_noise(self, latents, batch_size, dtype=torch.float32):
+        # get noise
+        noise = self.sd.get_latent_noise(
+            height=latents.shape[2],
+            width=latents.shape[3],
+            batch_size=batch_size,
+            noise_offset=self.train_config.noise_offset,
+        ).to(self.device_torch, dtype=dtype)
+
+        return noise
+
     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
         with torch.no_grad():
             with self.timer('prepare_prompt'):
@@ -927,12 +938,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 timesteps = torch.stack(timesteps, dim=0)
 
                 # get noise
-                noise = self.sd.get_latent_noise(
-                    height=latents.shape[2],
-                    width=latents.shape[3],
-                    batch_size=batch_size,
-                    noise_offset=self.train_config.noise_offset,
-                ).to(self.device_torch, dtype=dtype)
+                noise = self.get_noise(latents, batch_size, dtype=dtype)
 
                 # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
                 # this will negate any noise offsets
@@ -11,10 +11,10 @@ import json
 # te_path = "google/flan-t5-xl"
 # te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
-te_path = "google/flan-t5-large"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
+model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch"
+te_path = "google/flan-t5-xl"
+te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1"
 
 
 print("Loading te adapter")
@@ -28,13 +28,13 @@ is_pixart = "pixart" in model_path.lower()
 
 pipeline_class = StableDiffusionPipeline
 
-transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
+# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
 
 if is_pixart:
     pipeline_class = PixArtSigmaPipeline
 
 if is_diffusers:
-    sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
+    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
 else:
     sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
 
@@ -50,7 +50,7 @@ if is_pixart:
     unet = sd.transformer
     unet_sd = sd.transformer.state_dict()
 else:
-    unet = sd.transformer
+    unet = sd.unet
     unet_sd = sd.unet.state_dict()
 
 
@@ -187,7 +187,7 @@ for epoch in range(args.epochs):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
 
-        img_batch = color_block_imgs(img_batch, neg1_1=True)
+        # img_batch = color_block_imgs(img_batch, neg1_1=True)
 
         chunks = torch.chunk(img_batch, batch_size, dim=0)
         # put them so they are size by side
@@ -208,9 +208,9 @@ for epoch in range(args.epochs):
         # convert to image
         img = transforms.ToPILImage()(big_img)
 
-        show_img(img)
+        # show_img(img)
 
-        time.sleep(1.0)
+        # time.sleep(1.0)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)
@@ -39,7 +39,8 @@ from transformers import (
     AutoImageProcessor,
     ConvNextModel,
     ConvNextForImageClassification,
-    ConvNextImageProcessor
+    ConvNextImageProcessor,
+    UMT5EncoderModel, LlamaTokenizerFast
 )
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
 
@@ -165,6 +166,23 @@ class CustomAdapter(torch.nn.Module):
 
             # self.te.to = lambda *args, **kwargs: None
             self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
+        elif self.config.text_encoder_arch == 'pile-t5':
+            te_kwargs = {}
+            # te_kwargs['load_in_4bit'] = True
+            # te_kwargs['load_in_8bit'] = True
+            te_kwargs['device_map'] = "auto"
+            te_is_quantized = True
+
+            self.te = UMT5EncoderModel.from_pretrained(
+                self.config.text_encoder_path,
+                torch_dtype=torch_dtype,
+                **te_kwargs
+            )
+
+            # self.te.to = lambda *args, **kwargs: None
+            self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
         elif self.config.text_encoder_arch == 'clip':
             self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
                                                                                       dtype=torch_dtype)
@@ -433,7 +433,15 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         ])
 
         # this might take a while
         print(f"Dataset: {self.dataset_path}")
         print(f" - Preprocessing image dimensions")
+        dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
+        if os.path.exists(dataset_size_file):
+            with open(dataset_size_file, 'r') as f:
+                self.size_database = json.load(f)
+        else:
+            self.size_database = {}
+
         bad_count = 0
         for file in tqdm(file_list):
             try:
@@ -442,6 +450,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                     path=file,
                     dataset_config=dataset_config,
                     dataloader_transforms=self.transform,
+                    size_database=self.size_database,
                 )
                 self.file_list.append(file_item)
             except Exception as e:
@@ -450,6 +459,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                 print(e)
                 bad_count += 1
 
+        # save the size database
+        with open(dataset_size_file, 'w') as f:
+            json.dump(self.size_database, f)
+
         print(f" - Found {len(self.file_list)} images")
         # print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
@@ -1,3 +1,4 @@
+import os
 import weakref
 from _weakref import ReferenceType
 from typing import TYPE_CHECKING, List, Union
@@ -40,16 +41,22 @@ class FileItemDTO(
     ArgBreakMixin,
 ):
     def __init__(self, *args, **kwargs):
-        self.path = kwargs.get('path', None)
+        self.path = kwargs.get('path', '')
         self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
-        # process width and height
-        try:
-            w, h = image_utils.get_image_size(self.path)
-        except image_utils.UnknownImageFormat:
-            print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
-                       f'This process is faster for png, jpeg')
-            img = exif_transpose(Image.open(self.path))
-            h, w = img.size
+        size_database = kwargs.get('size_database', {})
+        filename = os.path.basename(self.path)
+        if filename in size_database:
+            w, h = size_database[filename]
+        else:
+            # process width and height
+            try:
+                w, h = image_utils.get_image_size(self.path)
+            except image_utils.UnknownImageFormat:
+                print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
+                           f'This process is faster for png, jpeg')
+                img = exif_transpose(Image.open(self.path))
+                h, w = img.size
+            size_database[filename] = (w, h)
         self.width: int = w
         self.height: int = h
         self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
@@ -243,7 +243,7 @@ class TEAdapter(torch.nn.Module):
         self.embeds_store = []
         is_pixart = sd.is_pixart
 
-        if self.adapter_ref().config.text_encoder_arch == "t5":
+        if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
            self.token_size = self.te_ref().config.d_model
         else:
             self.token_size = self.te_ref().config.target_hidden_size
@@ -388,13 +388,25 @@ class TEAdapter(torch.nn.Module):
             # ).input_ids.to(te.device)
             # outputs = te(input_ids=input_ids)
             # outputs = outputs.last_hidden_state
-            embeds, attention_mask = train_tools.encode_prompts_pixart(
-                tokenizer,
-                te,
-                text,
-                truncate=True,
-                max_length=self.adapter_ref().config.num_tokens,
-            )
+            if self.adapter_ref().config.text_encoder_arch == "pile-t5":
+                # just use aura pile
+                embeds, attention_mask = train_tools.encode_prompts_auraflow(
+                    tokenizer,
+                    te,
+                    text,
+                    truncate=True,
+                    max_length=self.adapter_ref().config.num_tokens,
+                )
+            else:
+                embeds, attention_mask = train_tools.encode_prompts_pixart(
+                    tokenizer,
+                    te,
+                    text,
+                    truncate=True,
+                    max_length=self.adapter_ref().config.num_tokens,
+                )
             attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
             if self.text_projection is not None:
                 # pool the output of embeds ignoring 0 in the attention mask