Bug fixes. Added the ability to use L1 loss. Various other tests and improvements.
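The L1 loss option mentioned in the commit message is not part of the hunks below. As a rough sketch only (the helper name, the `loss_type` value, and the call site are assumptions, not this repo's actual implementation), a selectable L1/MSE reconstruction loss could look like this:

    import torch
    import torch.nn.functional as F

    def reconstruction_loss(pred: torch.Tensor, target: torch.Tensor, loss_type: str = "mse") -> torch.Tensor:
        # hypothetical helper: pick the criterion from a config value
        if loss_type == "l1":
            return F.l1_loss(pred.float(), target.float(), reduction="mean")
        return F.mse_loss(pred.float(), target.float(), reduction="mean")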
@@ -12,7 +12,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -570,9 +570,18 @@ class ClipImageFileItemDTOMixin:
        self.has_clip_image = False
        self.clip_image_path: Union[str, None] = None
        self.clip_image_tensor: Union[torch.Tensor, None] = None
        self.clip_image_embeds: Union[dict, None] = None
        self.clip_image_embeds_unconditional: Union[dict, None] = None
        self.has_clip_augmentations = False
        self.clip_image_aug_transform: Union[None, A.Compose] = None
        self.clip_image_processor: Union[None, CLIPImageProcessor] = None
        self.clip_image_encoder_path: Union[str, None] = None
        self.is_caching_clip_vision_to_disk = False
        self.is_vision_clip_cached = False
        self.clip_vision_is_quad = False
        self.clip_vision_load_device = 'cpu'
        self.clip_vision_unconditional_paths: Union[List[str], None] = None
        self._clip_vision_embeddings_path: Union[str, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.clip_image_path is not None:
            # copy the clip image processor so the dataloader can do it
@@ -633,7 +642,45 @@ class ClipImageFileItemDTOMixin:

        return augmented_tensor

    def get_clip_vision_info_dict(self: 'FileItemDTO'):
        item = OrderedDict([
            ("image_encoder_path", self.clip_image_encoder_path),
            ("filename", os.path.basename(self.clip_image_path)),
            ("is_quad", self.clip_vision_is_quad)
        ])
        # when adding new keys, append them after the existing ones so we don't change the hash of previously cached items
        if self.flip_x:
            item["flip_x"] = True
        if self.flip_y:
            item["flip_y"] = True
        return item

    def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
        if self._clip_vision_embeddings_path is not None and not recalculate:
            return self._clip_vision_embeddings_path
        else:
            # clip vision embeddings are stored in a '_clip_vision_cache' folder next to the image
            img_dir = os.path.dirname(self.clip_image_path)
            latent_dir = os.path.join(img_dir, '_clip_vision_cache')
            hash_dict = self.get_clip_vision_info_dict()
            filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
            # get a urlsafe base64 string of the md5 checksum of hash_dict
            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
            hash_str = hash_str.replace('=', '')
            self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')

        return self._clip_vision_embeddings_path

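    # Illustrative sketch (not part of the commit): the cache filename is derived deterministically
    # from get_clip_vision_info_dict(), so any change to that dict forces a re-cache. For a
    # hypothetical item it works out to something like:
    #   hash_input = json.dumps(
    #       {"image_encoder_path": "openai/clip-vit-large-patch14", "filename": "img.png", "is_quad": False},
    #       sort_keys=True,
    #   ).encode('utf-8')
    #   hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
    #   path = os.path.join(img_dir, '_clip_vision_cache', f'img_{hash_str}.safetensors')
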
    def load_clip_image(self: 'FileItemDTO'):
        if self.is_vision_clip_cached:
            self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())

            # get a random unconditional image
            if self.clip_vision_unconditional_paths is not None:
                unconditional_path = random.choice(self.clip_vision_unconditional_paths)
                self.clip_image_embeds_unconditional = load_file(unconditional_path)

            return
        img = Image.open(self.clip_image_path).convert('RGB')
        try:
            img = exif_transpose(img)
@@ -683,6 +730,7 @@ class ClipImageFileItemDTOMixin:

    def cleanup_clip_image(self: 'FileItemDTO'):
        self.clip_image_tensor = None
        self.clip_image_embeds = None

@@ -1273,7 +1321,7 @@ class LatentCachingMixin:
                del latent
                del file_item.tensor

                flush(garbage_collect=False)
                # flush(garbage_collect=False)
                file_item.is_latent_cached = True
                i += 1
                # flush every 100
@@ -1282,3 +1330,176 @@ class LatentCachingMixin:

            # restore device state
            self.sd.restore_device_state()


class CLIPCachingMixin:
    def __init__(self: 'AiToolkitDataset', **kwargs):
        # if we have a super class, call its __init__
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.clip_vision_num_unconditional_cache = 20
        self.clip_vision_unconditional_cache = []

    def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
        if not self.is_caching_clip_vision_to_disk:
            return
        with torch.no_grad():
            print(f"Caching clip vision for {self.dataset_path}")

            print(" - Saving clip to disk")
            # move sd components to cpu except what is needed for clip caching
            self.sd.set_device_state_preset('cache_clip')

            # make sure we have an adapter with the attributes we need
            if self.sd.adapter is None:
                raise Exception("Error: must have an adapter to cache clip vision to disk")

            clip_image_processor: CLIPImageProcessor = None
            if hasattr(self.sd.adapter, 'clip_image_processor'):
                clip_image_processor = self.sd.adapter.clip_image_processor

            if clip_image_processor is None:
                raise Exception("Error: must have a clip image processor to cache clip vision to disk")

            vision_encoder: CLIPVisionModelWithProjection = None
            if hasattr(self.sd.adapter, 'image_encoder'):
                vision_encoder = self.sd.adapter.image_encoder
            if hasattr(self.sd.adapter, 'vision_encoder'):
                vision_encoder = self.sd.adapter.vision_encoder

            if vision_encoder is None:
                raise Exception("Error: must have a vision encoder to cache clip vision to disk")

            # move vision encoder to device
            vision_encoder.to(self.sd.device)

            is_quad = self.sd.adapter.config.quad_image
            image_encoder_path = self.sd.adapter.config.image_encoder_path

            dtype = self.sd.torch_dtype
            device = self.sd.device_torch
            if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero:
                # the unconditional is random noise, so more samples are needed; keep the default count
                self.clip_vision_num_unconditional_cache = self.clip_vision_num_unconditional_cache
            else:
                # only need one since it doesn't change
                self.clip_vision_num_unconditional_cache = 1

            # cache unconditionals
            print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
            clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')

            unconditional_paths = []

            is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero

            for i in range(self.clip_vision_num_unconditional_cache):
                hash_dict = OrderedDict([
                    ("image_encoder_path", image_encoder_path),
                    ("is_quad", is_quad),
                    ("is_noise_zero", is_noise_zero),
                ])
                # get a urlsafe base64 string of the md5 checksum of hash_dict
                hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
                hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
                hash_str = hash_str.replace('=', '')

                uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
                if os.path.exists(uncond_path):
                    # already cached, skip it
                    unconditional_paths.append(uncond_path)
                    continue

                # generate an unconditional image (random noise or all zeros)
                img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size)
                if is_noise_zero:
                    tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
                else:
                    tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
                clip_image = clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values

                if is_quad:
                    # split the 2x2 grid into four tiles and stack them on the batch dimension
                    ci1, ci2 = clip_image.chunk(2, dim=2)
                    ci1, ci3 = ci1.chunk(2, dim=3)
                    ci2, ci4 = ci2.chunk(2, dim=3)
                    clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
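                # Sketch of the quad layout (assumed shapes, for illustration only): a 2x2 grid image
                # of shape (1, 3, 448, 448) is chunked along H and then W into four (1, 3, 224, 224)
                # tiles, which are concatenated on dim 0 so the encoder sees a batch of four tiles:
                #   (1, 3, 448, 448) -> ci1..ci4 each (1, 3, 224, 224) -> clip_image (4, 3, 224, 224)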

                clip_output = vision_encoder(
                    clip_image.to(device, dtype=dtype),
                    output_hidden_states=True
                )
                # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
                state_dict = OrderedDict([
                    ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
                    ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
                    ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
                ])

                os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
                save_file(state_dict, uncond_path)
                unconditional_paths.append(uncond_path)

            self.clip_vision_unconditional_cache = unconditional_paths

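            # Sketch (illustration only): each cached file is a safetensors dict holding the keys
            # saved above; load_clip_image() later reads one back with the equivalent of:
            #   embeds = load_file(random.choice(unconditional_paths))
            #   embeds.keys()  # {'image_embeds', 'last_hidden_state', 'penultimate_hidden_states'}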
            # use tqdm to show progress
            i = 0
            for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'):
                file_item.is_caching_clip_vision_to_disk = True
                file_item.clip_vision_load_device = self.sd.device
                file_item.clip_vision_is_quad = is_quad
                file_item.clip_image_encoder_path = image_encoder_path
                file_item.clip_vision_unconditional_paths = unconditional_paths
                if file_item.has_clip_augmentations:
                    raise Exception("Error: clip vision caching is not supported with clip augmentations")

                embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
                # check if it is saved to disk already
                if not os.path.exists(embedding_path):
                    # load the image first
                    file_item.load_clip_image()
                    # add batch dimension
                    clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)

                    if is_quad:
                        # split the 2x2 grid into four tiles and stack them on the batch dimension
                        ci1, ci2 = clip_image.chunk(2, dim=2)
                        ci1, ci3 = ci1.chunk(2, dim=3)
                        ci2, ci4 = ci2.chunk(2, dim=3)
                        clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()

                    clip_output = vision_encoder(
                        clip_image.to(device, dtype=dtype),
                        output_hidden_states=True
                    )

                    # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
                    state_dict = OrderedDict([
                        ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
                        ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
                        ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
                    ])
                    # metadata
                    meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
                    os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
                    save_file(state_dict, embedding_path, metadata=meta)

                    del clip_image
                    del clip_output
                    del file_item.clip_image_tensor

                # flush(garbage_collect=False)
                file_item.is_vision_clip_cached = True
                i += 1
                # flush every 100
                # if i % 100 == 0:
                #     flush()

            # restore device state
            self.sd.restore_device_state()

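For context, a minimal usage sketch (the surrounding setup is assumed, not shown in this commit): once the dataset is built with a dataset_config whose clip_image_path is set and an adapter is attached to sd, the new caching step runs once up front, and later loads per-item embeddings from disk instead of re-encoding images:

    # hypothetical call site; the dataset class mixes in CLIPCachingMixin
    dataset.cache_clip_vision_to_disk()  # writes *_<hash>.safetensors files under _clip_vision_cache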