mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added caching to image sizes so we dont do it every time.
This commit is contained in:
@@ -714,6 +714,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
sigma = sigma.unsqueeze(-1)
|
sigma = sigma.unsqueeze(-1)
|
||||||
return sigma
|
return sigma
|
||||||
|
|
||||||
|
def get_noise(self, latents, batch_size, dtype=torch.float32):
|
||||||
|
# get noise
|
||||||
|
noise = self.sd.get_latent_noise(
|
||||||
|
height=latents.shape[2],
|
||||||
|
width=latents.shape[3],
|
||||||
|
batch_size=batch_size,
|
||||||
|
noise_offset=self.train_config.noise_offset,
|
||||||
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
return noise
|
||||||
|
|
||||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with self.timer('prepare_prompt'):
|
with self.timer('prepare_prompt'):
|
||||||
@@ -927,12 +938,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
timesteps = torch.stack(timesteps, dim=0)
|
timesteps = torch.stack(timesteps, dim=0)
|
||||||
|
|
||||||
# get noise
|
# get noise
|
||||||
noise = self.sd.get_latent_noise(
|
noise = self.get_noise(latents, batch_size, dtype=dtype)
|
||||||
height=latents.shape[2],
|
|
||||||
width=latents.shape[3],
|
|
||||||
batch_size=batch_size,
|
|
||||||
noise_offset=self.train_config.noise_offset,
|
|
||||||
).to(self.device_torch, dtype=dtype)
|
|
||||||
|
|
||||||
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
|
# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
|
||||||
# this will negate any noise offsets
|
# this will negate any noise offsets
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ import json
|
|||||||
# te_path = "google/flan-t5-xl"
|
# te_path = "google/flan-t5-xl"
|
||||||
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||||
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||||
model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
|
model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch"
|
||||||
te_path = "google/flan-t5-large"
|
te_path = "google/flan-t5-xl"
|
||||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
|
te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors"
|
||||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
|
output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1"
|
||||||
|
|
||||||
|
|
||||||
print("Loading te adapter")
|
print("Loading te adapter")
|
||||||
@@ -28,13 +28,13 @@ is_pixart = "pixart" in model_path.lower()
|
|||||||
|
|
||||||
pipeline_class = StableDiffusionPipeline
|
pipeline_class = StableDiffusionPipeline
|
||||||
|
|
||||||
transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
|
# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
|
||||||
|
|
||||||
if is_pixart:
|
if is_pixart:
|
||||||
pipeline_class = PixArtSigmaPipeline
|
pipeline_class = PixArtSigmaPipeline
|
||||||
|
|
||||||
if is_diffusers:
|
if is_diffusers:
|
||||||
sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
|
sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
else:
|
else:
|
||||||
sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
|
sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ if is_pixart:
|
|||||||
unet = sd.transformer
|
unet = sd.transformer
|
||||||
unet_sd = sd.transformer.state_dict()
|
unet_sd = sd.transformer.state_dict()
|
||||||
else:
|
else:
|
||||||
unet = sd.transformer
|
unet = sd.unet
|
||||||
unet_sd = sd.unet.state_dict()
|
unet_sd = sd.unet.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ for epoch in range(args.epochs):
|
|||||||
batch: 'DataLoaderBatchDTO'
|
batch: 'DataLoaderBatchDTO'
|
||||||
img_batch = batch.tensor
|
img_batch = batch.tensor
|
||||||
|
|
||||||
img_batch = color_block_imgs(img_batch, neg1_1=True)
|
# img_batch = color_block_imgs(img_batch, neg1_1=True)
|
||||||
|
|
||||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||||
# put them so they are size by side
|
# put them so they are size by side
|
||||||
@@ -208,9 +208,9 @@ for epoch in range(args.epochs):
|
|||||||
# convert to image
|
# convert to image
|
||||||
img = transforms.ToPILImage()(big_img)
|
img = transforms.ToPILImage()(big_img)
|
||||||
|
|
||||||
show_img(img)
|
# show_img(img)
|
||||||
|
|
||||||
time.sleep(1.0)
|
# time.sleep(1.0)
|
||||||
# if not last epoch
|
# if not last epoch
|
||||||
if epoch < args.epochs - 1:
|
if epoch < args.epochs - 1:
|
||||||
trigger_dataloader_setup_epoch(dataloader)
|
trigger_dataloader_setup_epoch(dataloader)
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ from transformers import (
|
|||||||
AutoImageProcessor,
|
AutoImageProcessor,
|
||||||
ConvNextModel,
|
ConvNextModel,
|
||||||
ConvNextForImageClassification,
|
ConvNextForImageClassification,
|
||||||
ConvNextImageProcessor
|
ConvNextImageProcessor,
|
||||||
|
UMT5EncoderModel, LlamaTokenizerFast
|
||||||
)
|
)
|
||||||
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
|
||||||
|
|
||||||
@@ -165,6 +166,23 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
|
|
||||||
# self.te.to = lambda *args, **kwargs: None
|
# self.te.to = lambda *args, **kwargs: None
|
||||||
self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
|
self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
|
||||||
|
elif self.config.text_encoder_arch == 'pile-t5':
|
||||||
|
te_kwargs = {}
|
||||||
|
# te_kwargs['load_in_4bit'] = True
|
||||||
|
# te_kwargs['load_in_8bit'] = True
|
||||||
|
te_kwargs['device_map'] = "auto"
|
||||||
|
te_is_quantized = True
|
||||||
|
|
||||||
|
self.te = UMT5EncoderModel.from_pretrained(
|
||||||
|
self.config.text_encoder_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
**te_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.te.to = lambda *args, **kwargs: None
|
||||||
|
self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||||
elif self.config.text_encoder_arch == 'clip':
|
elif self.config.text_encoder_arch == 'clip':
|
||||||
self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
|
self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
|
||||||
dtype=torch_dtype)
|
dtype=torch_dtype)
|
||||||
|
|||||||
@@ -433,7 +433,15 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
|||||||
])
|
])
|
||||||
|
|
||||||
# this might take a while
|
# this might take a while
|
||||||
|
print(f"Dataset: {self.dataset_path}")
|
||||||
print(f" - Preprocessing image dimensions")
|
print(f" - Preprocessing image dimensions")
|
||||||
|
dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
|
||||||
|
if os.path.exists(dataset_size_file):
|
||||||
|
with open(dataset_size_file, 'r') as f:
|
||||||
|
self.size_database = json.load(f)
|
||||||
|
else:
|
||||||
|
self.size_database = {}
|
||||||
|
|
||||||
bad_count = 0
|
bad_count = 0
|
||||||
for file in tqdm(file_list):
|
for file in tqdm(file_list):
|
||||||
try:
|
try:
|
||||||
@@ -442,6 +450,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
|||||||
path=file,
|
path=file,
|
||||||
dataset_config=dataset_config,
|
dataset_config=dataset_config,
|
||||||
dataloader_transforms=self.transform,
|
dataloader_transforms=self.transform,
|
||||||
|
size_database=self.size_database,
|
||||||
)
|
)
|
||||||
self.file_list.append(file_item)
|
self.file_list.append(file_item)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -450,6 +459,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
|||||||
print(e)
|
print(e)
|
||||||
bad_count += 1
|
bad_count += 1
|
||||||
|
|
||||||
|
# save the size database
|
||||||
|
with open(dataset_size_file, 'w') as f:
|
||||||
|
json.dump(self.size_database, f)
|
||||||
|
|
||||||
print(f" - Found {len(self.file_list)} images")
|
print(f" - Found {len(self.file_list)} images")
|
||||||
# print(f" - Found {bad_count} images that are too small")
|
# print(f" - Found {bad_count} images that are too small")
|
||||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import weakref
|
import weakref
|
||||||
from _weakref import ReferenceType
|
from _weakref import ReferenceType
|
||||||
from typing import TYPE_CHECKING, List, Union
|
from typing import TYPE_CHECKING, List, Union
|
||||||
@@ -40,16 +41,22 @@ class FileItemDTO(
|
|||||||
ArgBreakMixin,
|
ArgBreakMixin,
|
||||||
):
|
):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.path = kwargs.get('path', None)
|
self.path = kwargs.get('path', '')
|
||||||
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||||
# process width and height
|
size_database = kwargs.get('size_database', {})
|
||||||
try:
|
filename = os.path.basename(self.path)
|
||||||
w, h = image_utils.get_image_size(self.path)
|
if filename in size_database:
|
||||||
except image_utils.UnknownImageFormat:
|
w, h = size_database[filename]
|
||||||
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
else:
|
||||||
f'This process is faster for png, jpeg')
|
# process width and height
|
||||||
img = exif_transpose(Image.open(self.path))
|
try:
|
||||||
h, w = img.size
|
w, h = image_utils.get_image_size(self.path)
|
||||||
|
except image_utils.UnknownImageFormat:
|
||||||
|
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
||||||
|
f'This process is faster for png, jpeg')
|
||||||
|
img = exif_transpose(Image.open(self.path))
|
||||||
|
h, w = img.size
|
||||||
|
size_database[filename] = (w, h)
|
||||||
self.width: int = w
|
self.width: int = w
|
||||||
self.height: int = h
|
self.height: int = h
|
||||||
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
|
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
|
||||||
|
|||||||
@@ -243,7 +243,7 @@ class TEAdapter(torch.nn.Module):
|
|||||||
self.embeds_store = []
|
self.embeds_store = []
|
||||||
is_pixart = sd.is_pixart
|
is_pixart = sd.is_pixart
|
||||||
|
|
||||||
if self.adapter_ref().config.text_encoder_arch == "t5":
|
if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
|
||||||
self.token_size = self.te_ref().config.d_model
|
self.token_size = self.te_ref().config.d_model
|
||||||
else:
|
else:
|
||||||
self.token_size = self.te_ref().config.target_hidden_size
|
self.token_size = self.te_ref().config.target_hidden_size
|
||||||
@@ -388,13 +388,25 @@ class TEAdapter(torch.nn.Module):
|
|||||||
# ).input_ids.to(te.device)
|
# ).input_ids.to(te.device)
|
||||||
# outputs = te(input_ids=input_ids)
|
# outputs = te(input_ids=input_ids)
|
||||||
# outputs = outputs.last_hidden_state
|
# outputs = outputs.last_hidden_state
|
||||||
embeds, attention_mask = train_tools.encode_prompts_pixart(
|
|
||||||
tokenizer,
|
if self.adapter_ref().config.text_encoder_arch == "pile-t5":
|
||||||
te,
|
# just use aura pile
|
||||||
text,
|
embeds, attention_mask = train_tools.encode_prompts_auraflow(
|
||||||
truncate=True,
|
tokenizer,
|
||||||
max_length=self.adapter_ref().config.num_tokens,
|
te,
|
||||||
)
|
text,
|
||||||
|
truncate=True,
|
||||||
|
max_length=self.adapter_ref().config.num_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
embeds, attention_mask = train_tools.encode_prompts_pixart(
|
||||||
|
tokenizer,
|
||||||
|
te,
|
||||||
|
text,
|
||||||
|
truncate=True,
|
||||||
|
max_length=self.adapter_ref().config.num_tokens,
|
||||||
|
)
|
||||||
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
|
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
|
||||||
if self.text_projection is not None:
|
if self.text_projection is not None:
|
||||||
# pool the output of embeds ignoring 0 in the attention mask
|
# pool the output of embeds ignoring 0 in the attention mask
|
||||||
|
|||||||
Reference in New Issue
Block a user