Bug fixes; work on making IP adapters more customizable.

This commit is contained in:
Jaret Burkett
2023-12-24 08:32:39 -07:00
parent 7703e3a15e
commit 0f8daa5612
7 changed files with 243 additions and 36 deletions

View File

@@ -827,7 +827,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg:
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,

View File

@@ -151,6 +151,16 @@ class AdapterConfig:
self.num_tokens: int = num_tokens
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid
class ClipTokenMakerConfig:
    """Keyword-driven config holder: `image_encoder_path` (default None)
    and `num_tokens` (default 8). Unrecognized kwargs are ignored."""

    def __init__(self, **kwargs):
        # merge caller-supplied kwargs over the defaults, then bind attributes
        defaults = {'image_encoder_path': None, 'num_tokens': 8}
        merged = {**defaults, **kwargs}
        self.image_encoder_path: str = merged['image_encoder_path']
        self.num_tokens: int = merged['num_tokens']
@@ -393,7 +403,7 @@ class DatasetConfig:
None) # focus mask (black and white. White has higher loss than black)
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi',
None) # if one is set and in json data, will be used as auto crop scale point of interes
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
@@ -402,6 +412,8 @@ class DatasetConfig:
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
self.standardize_images: bool = kwargs.get('standardize_images', False)
# https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)

View File

@@ -8,6 +8,7 @@ from typing import List, TYPE_CHECKING
import cv2
import numpy as np
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
@@ -24,6 +25,45 @@ if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
class RescaleTransform:
    """Map image values from [0, 1] to [-1, 1] (same as Normalize([0.5], [0.5]))."""

    def __call__(self, image):
        # double, then shift: [0, 1] -> [0, 2] -> [-1, 1]
        doubled = image * 2
        return doubled - 1
class NormalizeSDXLTransform:
    """
    Normalize an image tensor with SDXL per-channel mean and std, based on
    averages over thousands of images (per the original measurement):
        Mean: tensor([ 0.0002, -0.1034, -0.1879])
        Standard Deviation: tensor([0.5436, 0.5116, 0.5033])
    """
    # Hoisted constants: the original rebuilt a transforms.Normalize object on
    # every call.  Computing (x - mean) / std directly avoids that, and the
    # (-1, 1, 1) view broadcasts over both (C, H, W) and (N, C, H, W) input.
    MEAN = torch.tensor([0.0002, -0.1034, -0.1879]).view(-1, 1, 1)
    STD = torch.tensor([0.5436, 0.5116, 0.5033]).view(-1, 1, 1)

    def __call__(self, image):
        # move constants to the image's device/dtype so CUDA/half inputs work
        mean = self.MEAN.to(device=image.device, dtype=image.dtype)
        std = self.STD.to(device=image.device, dtype=image.dtype)
        return (image - mean) / std
class NormalizeSD15Transform:
    """
    Normalize an image tensor with SD 1.5 per-channel mean and std, based on
    averages over thousands of images (the original docstring said "SDXL" by
    copy-paste error — these are the SD 1.5 statistics):
        Mean: tensor([-0.1600, -0.2450, -0.3227])
        Standard Deviation: tensor([0.5319, 0.4997, 0.5139])
    """
    # Hoisted constants: the original rebuilt a transforms.Normalize object on
    # every call.  Computing (x - mean) / std directly avoids that, and the
    # (-1, 1, 1) view broadcasts over both (C, H, W) and (N, C, H, W) input.
    MEAN = torch.tensor([-0.1600, -0.2450, -0.3227]).view(-1, 1, 1)
    STD = torch.tensor([0.5319, 0.4997, 0.5139]).view(-1, 1, 1)

    def __call__(self, image):
        # move constants to the image's device/dtype so CUDA/half inputs work
        mean = self.MEAN.to(device=image.device, dtype=image.dtype)
        std = self.STD.to(device=image.device, dtype=image.dtype)
        return (image - mean) / std
class ImageDataset(Dataset, CaptionMixin):
def __init__(self, config):
self.config = config
@@ -63,7 +103,7 @@ class ImageDataset(Dataset, CaptionMixin):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
RescaleTransform(),
])
def get_config(self, key, default=None, required=False):
@@ -200,7 +240,7 @@ class PairedImageDataset(Dataset):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
RescaleTransform(),
])
def get_all_prompts(self):
@@ -368,6 +408,23 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# repeat the list
file_list = file_list * self.dataset_config.num_repeats
if self.dataset_config.standardize_images:
if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
NormalizeMethod = NormalizeSDXLTransform
else:
NormalizeMethod = NormalizeSD15Transform
self.transform = transforms.Compose([
transforms.ToTensor(),
RescaleTransform(),
NormalizeMethod(),
])
else:
self.transform = transforms.Compose([
transforms.ToTensor(),
RescaleTransform(),
])
# this might take a while
print(f" - Preprocessing image dimensions")
bad_count = 0
@@ -375,7 +432,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
try:
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
dataset_config=dataset_config,
dataloader_transforms=self.transform,
)
self.file_list.append(file_item)
except Exception as e:
@@ -411,10 +469,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
if self.dataset_config.flip_x or self.dataset_config.flip_y:
print(f" - Found {len(self.file_list)} images after adding flips")
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
self.setup_epoch()

View File

@@ -48,6 +48,7 @@ class FileItemDTO(
h, w = img.size
self.width: int = w
self.height: int = h
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
super().__init__(*args, **kwargs)
# self.caption_path: str = kwargs.get('caption_path', None)
@@ -64,6 +65,7 @@ class FileItemDTO(
self.flip_y: bool = kwargs.get('flip_x', False)
self.augments: List[str] = self.dataset_config.augments
self.network_weight: float = self.dataset_config.network_weight
self.is_reg = self.dataset_config.is_reg
self.tensor: Union[torch.Tensor, None] = None

View File

@@ -56,6 +56,30 @@ transforms_dict = {
caption_ext_list = ['txt', 'json', 'caption']
def standardize_images(images):
    """
    Standardize the given batch of images using the CLIP mean and std.
    Expects values of 0 - 1.

    Args:
        images (torch.Tensor | sequence of torch.Tensor): a batch of images
            shaped (N, C, H, W), or a sequence of (C, H, W) tensors (the
            original loop-and-stack implementation accepted either).

    Returns:
        torch.Tensor: standardized images, shape (N, C, H, W).
    """
    # CLIP image-encoder statistics
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711])
    if not torch.is_tensor(images):
        # preserve support for lists/tuples of per-image tensors
        images = torch.stack(list(images))
    mean = mean.to(device=images.device, dtype=images.dtype).view(-1, 1, 1)
    std = std.to(device=images.device, dtype=images.dtype).view(-1, 1, 1)
    # Vectorized (x - mean) / std replaces the per-image Normalize + stack loop
    # with a single broadcast — same values, one kernel instead of N.
    return (images - mean) / std
def clean_caption(caption):
# remove any newlines
caption = caption.replace('\n', ', ')
@@ -520,8 +544,13 @@ class ControlFileItemDTOMixin:
))
else:
raise Exception("Control images not supported for non-bucket datasets")
self.control_tensor = transforms.ToTensor()(img)
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
self.control_tensor = self.augment_spatial_control(img, transform=transform)
else:
self.control_tensor = transform(img)
def cleanup_control(self: 'FileItemDTO'):
    # Drop the reference to the cached control image tensor so it can be freed.
    self.control_tensor = None
@@ -624,6 +653,8 @@ class AugmentationFileItemDTOMixin:
self.unaugmented_tensor: Union[torch.Tensor, None] = None
# self.augmentations: Union[None, List[Augments]] = None
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.aug_transform: Union[None, A.Compose] = None
self.aug_replay_spatial_transforms = None
self.build_augmentation_transform()
def build_augmentation_transform(self: 'FileItemDTO'):
@@ -643,7 +674,8 @@ class AugmentationFileItemDTOMixin:
# add the method to the list
augmentation_list.append(method(**aug.params))
self.aug_transform = A.Compose(augmentation_list)
# add additional targets so we can augment the control image
self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'})
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
@@ -659,7 +691,17 @@ class AugmentationFileItemDTOMixin:
open_cv_image = open_cv_image[:, :, ::-1].copy()
# apply augmentations
augmented = self.aug_transform(image=open_cv_image)["image"]
transformed = self.aug_transform(image=open_cv_image)
augmented = transformed["image"]
# save just the spatial transforms for controls and masks
augmented_params = transformed["replay"]
spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop',
'ElasticTransform', 'GridDistortion', 'OpticalDistortion']
# only store the spatial transforms
augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]
self.aug_replay_spatial_transforms = augmented_params
# convert back to RGB tensor
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
@@ -671,6 +713,38 @@ class AugmentationFileItemDTOMixin:
return augmented_tensor
# augment control images spatially consistent with transforms done to the main image
def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ):
    """Replay the recorded spatial augmentations onto a companion image
    (control / mask / unconditional) so it stays pixel-aligned with the
    already-augmented main image.

    Args:
        img: PIL image to augment; any mode (converted to RGB for
            albumentations, then back to the original mode).
        transform: optional torchvision Compose applied to the augmented PIL
            image; a bare ToTensor() is used when None.

    Returns:
        torch.Tensor of the (possibly augmented) image.
    """
    if self.aug_replay_spatial_transforms is None:
        # no transforms recorded yet — just run the supplied transform
        # NOTE(review): if `transform` is also None this raises TypeError;
        # the final line handles None but this branch does not — confirm
        # callers always pass a Compose here.
        return transform(img)
    # save colorspace to convert back to
    colorspace = img.mode
    # convert to rgb
    img = img.convert('RGB')
    open_cv_image = np.array(img)
    # Convert RGB to BGR (albumentations/cv2 operate on BGR arrays here)
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    # Replay the exact spatial transforms recorded during augment_image()
    transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image)
    augmented = transformed["image"]
    # convert back to RGB
    augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
    # convert to PIL image
    augmented = Image.fromarray(augmented)
    # convert back to original colorspace (e.g. grayscale masks)
    augmented = augmented.convert(colorspace)
    augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
    return augmented_tensor
def cleanup_control(self: 'FileItemDTO'):
    # Drop the cached unaugmented image tensor so it can be freed.
    # NOTE(review): named cleanup_control but clears `unaugmented_tensor`;
    # another mixin's cleanup_control clears `control_tensor` — confirm the
    # intended MRO/cooperation between the two.
    self.unaugmented_tensor = None
@@ -760,7 +834,13 @@ class MaskFileItemDTOMixin:
else:
raise Exception("Mask images not supported for non-bucket datasets")
self.mask_tensor = transforms.ToTensor()(img)
transform = transforms.Compose([
transforms.ToTensor(),
])
if self.aug_replay_spatial_transforms:
self.mask_tensor = self.augment_spatial_control(img, transform=transform)
else:
self.mask_tensor = transform(img)
self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
# convert to grayscale
@@ -776,12 +856,7 @@ class UnconditionalFileItemDTOMixin:
self.unconditional_path: Union[str, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latent: Union[torch.Tensor, None] = None
self.unconditional_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.unconditional_transforms = self.dataloader_transforms
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.unconditional_path is not None:
@@ -835,7 +910,10 @@ class UnconditionalFileItemDTOMixin:
else:
raise Exception("Unconditional images are not supported for non-bucket datasets")
self.unconditional_tensor = self.unconditional_transforms(img)
if self.aug_replay_spatial_transforms:
self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms)
else:
self.unconditional_tensor = self.unconditional_transforms(img)
def cleanup_unconditional(self: 'FileItemDTO'):
self.unconditional_tensor = None

View File

@@ -26,8 +26,15 @@ if TYPE_CHECKING:
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPVisionModel
CLIPVisionModel,
AutoImageProcessor,
ConvNextModel,
ConvNextForImageClassification,
ConvNextImageProcessor
)
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch.nn.functional as F
@@ -153,13 +160,51 @@ class IPAdapter(torch.nn.Module):
super().__init__()
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = CLIPImageProcessor()
self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path,
ignore_mismatched_sizes=True)
if self.config.image_encoder_arch == 'clip':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = CLIPImageProcessor()
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
adapter_config.image_encoder_path,
ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit':
try:
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = ViTFeatureExtractor()
self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'convnext':
try:
self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
print(f"could not load image processor from {adapter_config.image_encoder_path}")
self.clip_image_processor = ConvNextImageProcessor(
size=320,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
)
self.image_encoder = ConvNextForImageClassification.from_pretrained(
adapter_config.image_encoder_path,
use_safetensors=True,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'vit-hybrid':
try:
self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
print(f"could not load image processor from {adapter_config.image_encoder_path}")
self.clip_image_processor = ViTHybridImageProcessor(
size=320,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
)
self.image_encoder = ViTHybridForImageClassification.from_pretrained(
adapter_config.image_encoder_path,
use_safetensors=True,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
else:
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
self.current_scale = 1.0
self.is_active = True
if adapter_config.type == 'ip':
@@ -181,7 +226,7 @@ class IPAdapter(torch.nn.Module):
dim_head=64,
heads=heads,
num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size,
embedding_dim=self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1],
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
)
@@ -239,6 +284,10 @@ class IPAdapter(torch.nn.Module):
self.set_scale(1.0)
if self.config.train_image_encoder:
self.image_encoder.train()
self.image_encoder.requires_grad_(True)
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
@@ -280,7 +329,7 @@ class IPAdapter(torch.nn.Module):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -307,15 +356,18 @@ class IPAdapter(torch.nn.Module):
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if drop:
clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
clip_output = self.image_encoder(clip_image.requires_grad_(True)
, output_hidden_states=True)
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_image_embeds = clip_output.hidden_states[-2]
return clip_image_embeds
@@ -332,8 +384,16 @@ class IPAdapter(torch.nn.Module):
yield from self.image_proj_model.parameters(recurse)
if self.config.train_image_encoder:
yield from self.image_encoder.parameters(recurse)
# if self.config.train_image_encoder:
# yield from self.image_encoder.parameters(recurse)
# self.image_encoder.train()
# else:
# for attn_processor in self.adapter_modules:
# yield from attn_processor.parameters(recurse)
# yield from self.image_proj_model.parameters(recurse)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:

View File

@@ -39,10 +39,11 @@ class PromptEmbeds:
return self
def detach(self):
    """Return a detached copy of these prompt embeds.

    The source span contains diff residue: the superseded in-place
    implementation (detach self's tensors and return self) concatenated with
    the newer clone-based one.  Keep only the clone-based behavior: `self`
    is left untouched and a new, gradient-detached PromptEmbeds is returned.
    """
    detached = self.clone()
    detached.text_embeds = detached.text_embeds.detach()
    # pooled embeds exist only for some model types (e.g. SDXL-style outputs)
    if detached.pooled_embeds is not None:
        detached.pooled_embeds = detached.pooled_embeds.detach()
    return detached
def clone(self):
if self.pooled_embeds is not None: