mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Massive speed increase. Added latent caching, both to disk and to memory.
This commit is contained in:
@@ -1,14 +1,26 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, List, Dict, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from toolkit.basic import flush
|
||||
from toolkit.buckets import get_bucket_for_image_size
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.data_loader import AiToolkitDataset
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
@@ -219,7 +231,9 @@ class ImageProcessingDTOMixin:
|
||||
self: 'FileItemDTO',
|
||||
transform: Union[None, transforms.Compose]
|
||||
):
|
||||
# todo make sure this matches
|
||||
# if we are caching latents, just do that
|
||||
if self.is_latent_cached:
|
||||
self.get_latent()
|
||||
try:
|
||||
img = Image.open(self.path).convert('RGB')
|
||||
img = exif_transpose(img)
|
||||
@@ -265,3 +279,139 @@ class ImageProcessingDTOMixin:
|
||||
img = transform(img)
|
||||
|
||||
self.tensor = img
|
||||
|
||||
|
||||
class LatentCachingFileItemDTOMixin:
    """Mixin for a file item that tracks its encoded latent and cache location.

    The latent can live in memory (``_encoded_latent``) and/or in a
    safetensors file whose name is derived from the image file name plus a
    hash of the values returned by ``get_latent_info_dict``.
    """

    def __init__(self):
        # cooperate with whatever comes next in the MRO
        parent = super()
        if hasattr(parent, '__init__'):
            parent.__init__()

        # in-memory latent and memoized cache file path
        self._encoded_latent: Union[torch.Tensor, None] = None
        self._latent_path: Union[str, None] = None

        # caching state flags
        self.is_latent_cached = False
        self.is_caching_to_disk = False
        self.is_caching_to_memory = False

        # device used by get_latent() when the caller does not pass one
        self.latent_load_device = 'cpu'
        # sd1 or sdxl or others
        self.latent_space_version = 'sd1'
        # todo, increment this if we change the latent format to invalidate cache
        self.latent_version = 1

    def get_latent_info_dict(self: 'FileItemDTO'):
        """Return an ordered mapping of the values that identify this latent.

        The mapping is hashed by ``get_latent_path`` to build the cache file
        name, so a change to any value yields a different cache file.
        """
        info = OrderedDict()
        info["filename"] = os.path.basename(self.path)
        info["scale_to_width"] = self.scale_to_width
        info["scale_to_height"] = self.scale_to_height
        info["crop_x"] = self.crop_x
        info["crop_y"] = self.crop_y
        info["crop_width"] = self.crop_width
        info["crop_height"] = self.crop_height
        info["latent_space_version"] = self.latent_space_version
        info["latent_version"] = self.latent_version
        return info

    def get_latent_path(self: 'FileItemDTO', recalculate=False):
        """Return the path of the cached latent file for this item.

        The path is computed once and memoized; pass ``recalculate=True`` to
        rebuild it from the current values of ``get_latent_info_dict``.
        """
        if recalculate or self._latent_path is None:
            # latents live in a `_latent_cache` folder next to the image
            image_dir, image_name = os.path.split(self.path)
            stem = os.path.splitext(image_name)[0]

            # url-safe base64 of the md5 digest of the info dict, padding removed
            payload = json.dumps(self.get_latent_info_dict(), sort_keys=True)
            digest = hashlib.md5(payload.encode('utf-8')).digest()
            token = base64.urlsafe_b64encode(digest).decode('ascii').replace('=', '')

            self._latent_path = os.path.join(
                image_dir, '_latent_cache', f'{stem}_{token}.safetensors'
            )

        return self._latent_path

    def cleanup_latent(self):
        """Release the in-memory latent, or move it back to the CPU.

        When the item is not being cached to memory the latent is dropped;
        otherwise it is kept and moved to ``'cpu'``.
        """
        if self._encoded_latent is None:
            return
        if self.is_caching_to_memory:
            # keep it around, but on the cpu
            self._encoded_latent = self._encoded_latent.to('cpu')
        else:
            # we are caching on disk, don't save in memory
            self._encoded_latent = None

    def get_latent(self, device=None):
        """Return the cached latent, loading it from disk if it is not in memory.

        Returns ``None`` when no latent has been cached for this item.
        ``device`` is only used when the latent has to be loaded from disk;
        if omitted, ``latent_load_device`` is used.
        """
        if not self.is_latent_cached:
            return None

        if self._encoded_latent is None:
            # not in memory, read it back from the cache file
            cache_file = self.get_latent_path()
            target_device = self.latent_load_device if device is None else device
            tensors = load_file(cache_file, device=target_device)
            self._encoded_latent = tensors['latent']

        return self._encoded_latent
|
||||
|
||||
|
||||
class LatentCachingMixin:
    """Dataset mixin that pre-computes and caches a latent for every file item.

    Relies on the host dataset providing ``dataset_path``, ``file_list``,
    ``transform``, ``sd``, ``is_caching_latents_to_disk`` and
    ``is_caching_latents_to_memory``.
    """

    def __init__(self: 'AiToolkitDataset', **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.latent_cache = {}

    def cache_latents_all_latents(self: 'AiToolkitDataset'):
        """Encode (or reuse from disk) the latent of every item in ``file_list``.

        For each file item the cache path is recalculated. If a cache file
        already exists it is reused (and loaded into memory when memory
        caching is enabled); otherwise the image is loaded, encoded with
        ``self.sd.encode_images`` and stored to disk and/or memory as
        configured. Every processed item is flagged ``is_latent_cached``.

        The model's device state is switched to the ``'cache_latents'``
        preset for the duration of the run and is restored afterwards, even
        if caching fails part-way through.
        """
        print(f"Caching latents for {self.dataset_path}")
        # cache all latents to disk
        to_disk = self.is_caching_latents_to_disk
        to_memory = self.is_caching_latents_to_memory

        if to_disk:
            print(" - Saving latents to disk")
        if to_memory:
            print(" - Keeping latents in memory")
        # move sd items to cpu except for vae
        self.sd.set_device_state_preset('cache_latents')

        try:
            # the latent space is a property of the model, not of the item
            latent_space_version = 'sdxl' if self.sd.is_xl else 'sd1'

            # use tqdm to show progress
            for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
                # these feed the cache-file hash, so set them before asking for the path
                file_item.latent_space_version = latent_space_version
                file_item.is_caching_to_disk = to_disk
                file_item.is_caching_to_memory = to_memory
                file_item.latent_load_device = self.sd.device

                latent_path = file_item.get_latent_path(recalculate=True)
                # check if it is saved to disk already
                if os.path.exists(latent_path):
                    if to_memory:
                        # load it into memory
                        state_dict = load_file(latent_path, device='cpu')
                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
                else:
                    # not saved to disk, calculate
                    # load the image first
                    file_item.load_and_process_image(self.transform)
                    dtype = self.sd.torch_dtype
                    device = self.sd.device_torch
                    # add batch dimension
                    imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                    latent = self.sd.encode_images(imgs).squeeze(0)
                    # save_latent
                    if to_disk:
                        state_dict = OrderedDict([
                            ('latent', latent.clone().detach().cpu()),
                        ])
                        # metadata
                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
                        save_file(state_dict, latent_path, metadata=meta)

                    if to_memory:
                        # keep it in memory; use the same torch dtype as the
                        # loaded-from-disk branch above so both paths agree
                        file_item._encoded_latent = latent.to('cpu', dtype=dtype)

                    flush(garbage_collect=False)
                file_item.is_latent_cached = True
        finally:
            # restore device state, also when an item failed to load or encode
            self.sd.restore_device_state()
|
||||
|
||||
Reference in New Issue
Block a user