Bug fixes. Added the ability to use L1 loss. Various other tests and improvements.

Jaret Burkett
2024-01-31 06:30:54 -07:00
parent 92b9c71d44
commit 1ae1017748
9 changed files with 474 additions and 23 deletions


@@ -12,7 +12,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -570,9 +570,18 @@ class ClipImageFileItemDTOMixin:
self.has_clip_image = False
self.clip_image_path: Union[str, None] = None
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.clip_image_embeds: Union[dict, None] = None
self.clip_image_embeds_unconditional: Union[dict, None] = None
self.has_clip_augmentations = False
self.clip_image_aug_transform: Union[None, A.Compose] = None
self.clip_image_processor: Union[None, CLIPImageProcessor] = None
self.clip_image_encoder_path: Union[str, None] = None
self.is_caching_clip_vision_to_disk = False
self.is_vision_clip_cached = False
self.clip_vision_is_quad = False
self.clip_vision_load_device = 'cpu'
self.clip_vision_unconditional_paths: Union[List[str], None] = None
self._clip_vision_embeddings_path: Union[str, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
# copy the clip image processor so the dataloader can do it
@@ -633,7 +642,45 @@ class ClipImageFileItemDTOMixin:
return augmented_tensor
def get_clip_vision_info_dict(self: 'FileItemDTO'):
item = OrderedDict([
("image_encoder_path", self.clip_image_encoder_path),
("filename", os.path.basename(self.clip_image_path)),
("is_quad", self.clip_vision_is_quad)
])
# when adding new keys, append them after the existing ones (and only when set) so old cache files stay valid
if self.flip_x:
item["flip_x"] = True
if self.flip_y:
item["flip_y"] = True
return item
def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
if self._clip_vision_embeddings_path is not None and not recalculate:
return self._clip_vision_embeddings_path
else:
# we store the embeddings in a _clip_vision_cache folder next to the image
img_dir = os.path.dirname(self.clip_image_path)
latent_dir = os.path.join(img_dir, '_clip_vision_cache')
hash_dict = self.get_clip_vision_info_dict()
filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
# urlsafe-base64 encode the md5 digest of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._clip_vision_embeddings_path
def load_clip_image(self: 'FileItemDTO'):
if self.is_vision_clip_cached:
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
# pick a random cached unconditional embedding
if self.clip_vision_unconditional_paths is not None:
unconditional_path = random.choice(self.clip_vision_unconditional_paths)
self.clip_image_embeds_unconditional = load_file(unconditional_path)
return
img = Image.open(self.clip_image_path).convert('RGB')
try:
img = exif_transpose(img)
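The cache filename scheme from get_clip_vision_embeddings_path() can be reproduced standalone. A minimal sketch of the md5-to-urlsafe-base64 step, with made-up dict values for illustration:

import base64
import hashlib
import json
from collections import OrderedDict

# the values below are illustrative, not from the commit
hash_dict = OrderedDict([
    ("image_encoder_path", "openai/clip-vit-large-patch14"),
    ("filename", "dog_001.jpg"),
    ("is_quad", False),
])
hash_input = json.dumps(hash_dict, sort_keys=True).encode("utf-8")
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode("ascii")
hash_str = hash_str.replace("=", "")
print(hash_str)  # 22 url-safe chars: 16-byte digest, base64 padding stripped
# cached file: _clip_vision_cache/dog_001_<hash_str>.safetensors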
@@ -683,6 +730,7 @@ class ClipImageFileItemDTOMixin:
def cleanup_clip_image(self: 'FileItemDTO'):
self.clip_image_tensor = None
self.clip_image_embeds = None
@@ -1273,7 +1321,7 @@ class LatentCachingMixin:
del latent
del file_item.tensor
flush(garbage_collect=False)
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
@@ -1282,3 +1330,176 @@ class LatentCachingMixin:
# restore device state
self.sd.restore_device_state()
class CLIPCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
self.clip_vision_num_unconditional_cache = 20
self.clip_vision_unconditional_cache = []
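# populated with the cached unconditional .safetensors paths by cache_clip_vision_to_disk() below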
def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
if not self.is_caching_clip_vision_to_disk:
return
with torch.no_grad():
print(f"Caching clip vision for {self.dataset_path}")
print(" - Saving clip to disk")
# move sd modules to cpu via the 'cache_clip' device state preset
self.sd.set_device_state_preset('cache_clip')
# make sure the adapter has attributes
if self.sd.adapter is None:
raise Exception("Error: must have an adapter to cache clip vision to disk")
clip_image_processor: Union[CLIPImageProcessor, None] = None
if hasattr(self.sd.adapter, 'clip_image_processor'):
clip_image_processor = self.sd.adapter.clip_image_processor
if clip_image_processor is None:
raise Exception("Error: must have a clip image processor to cache clip vision to disk")
vision_encoder: Union[CLIPVisionModelWithProjection, None] = None
if hasattr(self.sd.adapter, 'image_encoder'):
vision_encoder = self.sd.adapter.image_encoder
if hasattr(self.sd.adapter, 'vision_encoder'):
vision_encoder = self.sd.adapter.vision_encoder
if vision_encoder is None:
raise Exception("Error: must have a vision encoder to cache clip vision to disk")
# move vision encoder to device
vision_encoder.to(self.sd.device)
is_quad = self.sd.adapter.config.quad_image
image_encoder_path = self.sd.adapter.config.image_encoder_path
dtype = self.sd.torch_dtype
device = self.sd.device_torch
if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero:
# the unconditional input is random noise, so keep the default number of samples
pass
else:
# a zero image never changes, so a single cached unconditional is enough
self.clip_vision_num_unconditional_cache = 1
# cache unconditionals
print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
unconditional_paths = []
is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero
for i in range(self.clip_vision_num_unconditional_cache):
hash_dict = OrderedDict([
("image_encoder_path", image_encoder_path),
("is_quad", is_quad),
("is_noise_zero", is_noise_zero),
])
# urlsafe-base64 encode the md5 digest of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
if os.path.exists(uncond_path):
# already cached, just record the path
unconditional_paths.append(uncond_path)
continue
# generate the unconditional input image (random noise or zeros)
img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size)
if is_noise_zero:
tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
else:
tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
clip_image = clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
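# do_rescale=False: the tensors above are already floats in [0, 1], so the
# processor must not divide by 255 a second time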
if is_quad:
# split the 2x2 grid into quadrants and stack them on the batch dimension
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
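# each chunk is (1, 3, H/2, W/2); the cat yields a (4, 3, H/2, W/2) batch of tiles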
clip_output = vision_encoder(
clip_image.to(device, dtype=dtype),
output_hidden_states=True
)
# make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
state_dict = OrderedDict([
('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
])
os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
save_file(state_dict, uncond_path)
unconditional_paths.append(uncond_path)
self.clip_vision_unconditional_cache = unconditional_paths
# use tqdm to show progress
i = 0
for file_item in tqdm(self.file_list, desc='Caching clip vision to disk'):
file_item.is_caching_clip_vision_to_disk = True
file_item.clip_vision_load_device = self.sd.device
file_item.clip_vision_is_quad = is_quad
file_item.clip_image_encoder_path = image_encoder_path
file_item.clip_vision_unconditional_paths = unconditional_paths
if file_item.has_clip_augmentations:
raise Exception("Error: clip vision caching is not supported with clip augmentations")
embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
# check if it is saved to disk already
if not os.path.exists(embedding_path):
# load the image first
file_item.load_clip_image()
# add batch dimension
clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)
if is_quad:
# split the 2x2 grid into quadrants and stack them on the batch dimension
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
clip_output = vision_encoder(
clip_image.to(device, dtype=dtype),
output_hidden_states=True
)
# make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
state_dict = OrderedDict([
('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
save_file(state_dict, embedding_path, metadata=meta)
del clip_image
del clip_output
del file_item.clip_image_tensor
# flush(garbage_collect=False)
file_item.is_vision_clip_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()
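The quad split used in both loops above is easy to sanity-check in isolation. A minimal sketch, assuming 224px tiles (the real size comes from the adapter's input_size):

import torch

grid = torch.randn(1, 3, 448, 448)   # one image holding a 2x2 grid of tiles
top, bottom = grid.chunk(2, dim=2)   # ci1, ci2 in the code above: split height
tl, tr = top.chunk(2, dim=3)         # ci1, ci3: split width of the top half
bl, br = bottom.chunk(2, dim=3)      # ci2, ci4: split width of the bottom half
tiles = torch.cat([tl, bl, tr, br], dim=0)
assert tiles.shape == (4, 3, 224, 224)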