Fixes for sdxl ip adapter training. Bug fixes

This commit is contained in:
Jaret Burkett
2023-12-21 11:15:58 -07:00
parent 0f597f453e
commit 7703e3a15e
5 changed files with 63 additions and 11 deletions

View File

@@ -173,12 +173,9 @@ class SDTrainer(BaseSDTrainProcess):
if torch.isnan(prior_loss).any():
raise ValueError("Prior loss is nan")
prior_loss = prior_loss.mean([1, 2, 3])
loss = loss.mean([1, 2, 3])
if prior_loss is not None:
# prior_loss = prior_loss.mean([1, 2, 3])
loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if self.train_config.learnable_snr_gos:
# add snr_gamma

View File

@@ -423,6 +423,8 @@ class DatasetConfig:
# ip adapter / reference dataset
self.clip_image_path: str = kwargs.get('clip_image_path', None) # depth maps, etc
self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -534,6 +534,8 @@ class ClipImageFileItemDTOMixin:
self.has_clip_image = False
self.clip_image_path: Union[str, None] = None
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.has_clip_augmentations = False
self.clip_image_aug_transform: Union[None, A.Compose] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
# find the control image path
@@ -548,6 +550,51 @@ class ClipImageFileItemDTOMixin:
self.has_clip_image = True
break
self.build_clip_imag_augmentation_transform()
def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0:
self.has_clip_augmentations = True
augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations]
if self.dataset_config.clip_image_shuffle_augmentations:
random.shuffle(augmentations)
augmentation_list = []
for aug in augmentations:
# make sure method name is valid
assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
# get the method
method = getattr(A, aug.method_name)
# add the method to the list
augmentation_list.append(method(**aug.params))
self.clip_image_aug_transform = A.Compose(augmentation_list)
def augment_clip_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
if self.dataset_config.clip_image_shuffle_augmentations:
self.build_clip_imag_augmentation_transform()
# save the original tensor
self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
open_cv_image = np.array(img)
# Convert RGB to BGR
open_cv_image = open_cv_image[:, :, ::-1].copy()
# apply augmentations
augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
# convert back to RGB tensor
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
# convert to PIL image
augmented = Image.fromarray(augmented)
augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
return augmented_tensor
def load_clip_image(self: 'FileItemDTO'):
img = Image.open(self.clip_image_path).convert('RGB')
try:
@@ -558,8 +605,10 @@ class ClipImageFileItemDTOMixin:
# we just scale them to 512x512:
img = img.resize((512, 512), Image.BICUBIC)
self.clip_image_tensor = transforms.ToTensor()(img)
if self.has_clip_augmentations:
self.clip_image_tensor = self.augment_clip_image(img, transform=None)
else:
self.clip_image_tensor = transforms.ToTensor()(img)
def cleanup_clip_image(self: 'FileItemDTO'):
self.clip_image_tensor = None

View File

@@ -432,7 +432,7 @@ def show_img(img, name='AI Toolkit'):
img = np.clip(img, 0, 255).astype(np.uint8)
cv2.imshow(name, img[:, :, ::-1])
k = cv2.waitKey(10) & 0xFF
k = cv2.waitKey(1) & 0xFF
if k == 27: # Esc key to stop
print('\nESC pressed, stopping')
raise KeyboardInterrupt

View File

@@ -170,12 +170,16 @@ class IPAdapter(torch.nn.Module):
clip_extra_context_tokens=self.config.num_tokens, # usually 4
)
elif adapter_config.type == 'ip+':
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
# size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
# size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
# ip-adapter-plus
image_proj_model = Resampler(
dim=sd.unet.config['cross_attention_dim'],
dim=dim,
depth=4,
dim_head=64,
heads=12,
heads=heads,
num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
@@ -266,7 +270,7 @@ class IPAdapter(torch.nn.Module):
def set_scale(self, scale):
self.current_scale = scale
for attn_processor in self.sd_ref().unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
if isinstance(attn_processor, CustomIPAttentionProcessor):
attn_processor.scale = scale
@torch.no_grad()