Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Fixes for SDXL IP adapter training. Bug fixes.
@@ -173,12 +173,9 @@ class SDTrainer(BaseSDTrainProcess):
             if torch.isnan(prior_loss).any():
                 raise ValueError("Prior loss is nan")
-            prior_loss = prior_loss.mean([1, 2, 3])
-
-        loss = loss.mean([1, 2, 3])
-
         if prior_loss is not None:
+            # prior_loss = prior_loss.mean([1, 2, 3])
             loss = loss + prior_loss
+        loss = loss.mean([1, 2, 3])
 
         if self.train_config.learnable_snr_gos:
             # add snr_gamma
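The reduction-order fix keeps the NaN guard but now folds the unreduced prior loss in element-wise, then applies a single mean over the channel and spatial dims. A minimal sketch of the new order, using assumed (batch, C, H, W) dummy tensors rather than the trainer's real losses:

    import torch

    # dummy per-element losses, shape (batch, C, H, W) -- assumed, for illustration
    loss = torch.rand(4, 4, 64, 64)
    prior_loss = torch.rand(4, 4, 64, 64)

    if torch.isnan(prior_loss).any():
        raise ValueError("Prior loss is nan")

    # prior loss is added element-wise first...
    loss = loss + prior_loss
    # ...then one reduction over C, H, W yields a single value per sample
    loss = loss.mean([1, 2, 3])
    assert loss.shape == (4,)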
@@ -423,6 +423,8 @@ class DatasetConfig:
 
         # ip adapter / reference dataset
         self.clip_image_path: str = kwargs.get('clip_image_path', None)  # depth maps, etc
+        self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
+        self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
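The two new kwargs make the clip/reference-image augmentations config-driven. A hypothetical config entry, assuming each list item carries the method_name and params keys that Augments(**aug) consumes in the mixin below (the path and transform choices are illustrative, not from the repo):

    # hypothetical DatasetConfig kwargs -- keys inferred from the mixin code below
    dataset_kwargs = {
        'clip_image_path': '/path/to/reference/images',  # placeholder path
        'clip_image_augmentations': [
            {'method_name': 'GaussianBlur', 'params': {'blur_limit': (3, 7), 'p': 0.5}},
            {'method_name': 'HorizontalFlip', 'params': {'p': 0.5}},
        ],
        'clip_image_shuffle_augmentations': True,  # rebuild + shuffle order per image
    }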
@@ -534,6 +534,8 @@ class ClipImageFileItemDTOMixin:
         self.has_clip_image = False
         self.clip_image_path: Union[str, None] = None
         self.clip_image_tensor: Union[torch.Tensor, None] = None
+        self.has_clip_augmentations = False
+        self.clip_image_aug_transform: Union[None, A.Compose] = None
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         if dataset_config.clip_image_path is not None:
             # find the control image path
@@ -548,6 +550,51 @@ class ClipImageFileItemDTOMixin:
                     self.has_clip_image = True
                     break
+        self.build_clip_imag_augmentation_transform()
+
+    def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
+        if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0:
+            self.has_clip_augmentations = True
+            augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations]
+
+            if self.dataset_config.clip_image_shuffle_augmentations:
+                random.shuffle(augmentations)
+
+            augmentation_list = []
+            for aug in augmentations:
+                # make sure method name is valid
+                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
+                # get the method
+                method = getattr(A, aug.method_name)
+                # add the method to the list
+                augmentation_list.append(method(**aug.params))
+
+            self.clip_image_aug_transform = A.Compose(augmentation_list)
+
+    def augment_clip_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose]):
+        if self.dataset_config.clip_image_shuffle_augmentations:
+            self.build_clip_imag_augmentation_transform()
+
+        # save the original tensor
+        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
+
+        open_cv_image = np.array(img)
+        # Convert RGB to BGR
+        open_cv_image = open_cv_image[:, :, ::-1].copy()
+
+        # apply augmentations
+        augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
+
+        # convert back to RGB tensor
+        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
+
+        # convert to PIL image
+        augmented = Image.fromarray(augmented)
+
+        augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
+
+        return augmented_tensor
+
 
     def load_clip_image(self: 'FileItemDTO'):
         img = Image.open(self.clip_image_path).convert('RGB')
         try:
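The augment path converts the PIL image to a BGR numpy array (the layout cv2 expects), runs the albumentations pipeline, then converts back to an RGB tensor. A self-contained sketch of that round-trip, with a hardcoded flip standing in for the config-driven transform list:

    import albumentations as A
    import cv2
    import numpy as np
    from PIL import Image
    from torchvision import transforms

    # assumed pipeline -- one deterministic flip instead of the config-driven list
    aug_transform = A.Compose([A.HorizontalFlip(p=1.0)])

    # random RGB image standing in for a real reference image
    img = Image.fromarray(np.uint8(np.random.rand(512, 512, 3) * 255)).convert('RGB')

    open_cv_image = np.array(img)[:, :, ::-1].copy()          # RGB -> BGR
    augmented = aug_transform(image=open_cv_image)["image"]   # numpy HWC, still BGR
    augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)    # back to RGB
    augmented_tensor = transforms.ToTensor()(Image.fromarray(augmented))
    print(augmented_tensor.shape)  # torch.Size([3, 512, 512]), values in [0, 1]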
@@ -558,8 +605,10 @@ class ClipImageFileItemDTOMixin:
 
             # we just scale them to 512x512:
             img = img.resize((512, 512), Image.BICUBIC)
-        self.clip_image_tensor = transforms.ToTensor()(img)
-
+        if self.has_clip_augmentations:
+            self.clip_image_tensor = self.augment_clip_image(img, transform=None)
+        else:
+            self.clip_image_tensor = transforms.ToTensor()(img)
 
     def cleanup_clip_image(self: 'FileItemDTO'):
         self.clip_image_tensor = None
@@ -432,7 +432,7 @@ def show_img(img, name='AI Toolkit'):
 
     img = np.clip(img, 0, 255).astype(np.uint8)
     cv2.imshow(name, img[:, :, ::-1])
-    k = cv2.waitKey(10) & 0xFF
+    k = cv2.waitKey(1) & 0xFF
     if k == 27:  # Esc key to stop
         print('\nESC pressed, stopping')
         raise KeyboardInterrupt
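Dropping cv2.waitKey(10) to cv2.waitKey(1) cuts the blocking poll from ~10 ms to ~1 ms per preview frame while keeping the Esc handling intact. A minimal standalone sketch of the same key check (a one-off display, not the repo's function):

    import cv2
    import numpy as np

    # random image standing in for a sample preview
    img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
    cv2.imshow('AI Toolkit', img[:, :, ::-1])  # RGB -> BGR for cv2
    k = cv2.waitKey(1) & 0xFF                  # ~1 ms poll; mask to a single byte
    if k == 27:                                # Esc key code
        print('\nESC pressed, stopping')
        raise KeyboardInterrupt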
@@ -170,12 +170,16 @@ class IPAdapter(torch.nn.Module):
                 clip_extra_context_tokens=self.config.num_tokens,  # usually 4
             )
         elif adapter_config.type == 'ip+':
+            heads = 12 if not sd.is_xl else 20
+            dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
+            # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
+            # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
             # ip-adapter-plus
             image_proj_model = Resampler(
-                dim=sd.unet.config['cross_attention_dim'],
+                dim=dim,
                 depth=4,
                 dim_head=64,
-                heads=12,
+                heads=heads,
                 num_queries=self.config.num_tokens,  # usually 16
                 embedding_dim=self.image_encoder.config.hidden_size,
                 output_dim=sd.unet.config['cross_attention_dim'],
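The two commented size-mismatch errors point at the Resampler's learned latents parameter, shaped (1, num_queries, dim): building with the UNet's cross-attention dim (2048 on SDXL) instead of 1280 made pretrained ip+ checkpoints unloadable, which the heads/dim switch above addresses. A toy stand-in (shapes illustrative, not the real Resampler) showing how such a mismatch surfaces:

    import torch
    import torch.nn as nn

    # toy stand-in for the Resampler's learned query latents
    class ToyResampler(nn.Module):
        def __init__(self, dim: int, num_queries: int):
            super().__init__()
            self.latents = nn.Parameter(torch.randn(1, num_queries, dim))

    sd15 = ToyResampler(dim=768, num_queries=16)   # SD 1.5-style dim
    sdxl = ToyResampler(dim=1280, num_queries=16)  # SDXL ip+ dim per the fix

    try:
        # loading weights built with one dim into a model built with another fails
        sd15.load_state_dict(sdxl.state_dict())
    except RuntimeError as e:
        print(e)  # size mismatch for latents: ... [1, 16, 1280] vs [1, 16, 768]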
@@ -266,7 +270,7 @@ class IPAdapter(torch.nn.Module):
     def set_scale(self, scale):
         self.current_scale = scale
         for attn_processor in self.sd_ref().unet.attn_processors.values():
-            if isinstance(attn_processor, IPAttnProcessor):
+            if isinstance(attn_processor, CustomIPAttentionProcessor):
                 attn_processor.scale = scale
 
     @torch.no_grad()
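set_scale walks every attention processor on the UNet and updates only the IP-adapter ones; the fix checks against the project's own CustomIPAttentionProcessor rather than the diffusers IPAttnProcessor, so the isinstance test matches the processors that were actually installed. A toy sketch with stand-in classes (not the repo's):

    # stand-in classes; only the isinstance-based dispatch is the point here
    class IPAttnProcessor:
        pass

    class CustomIPAttentionProcessor:
        def __init__(self):
            self.scale = 1.0

    attn_processors = {
        'down_blocks.0.attn1.processor': IPAttnProcessor(),            # never matched
        'down_blocks.0.attn2.processor': CustomIPAttentionProcessor(),
    }

    def set_scale(scale: float) -> None:
        for attn_processor in attn_processors.values():
            if isinstance(attn_processor, CustomIPAttentionProcessor):
                attn_processor.scale = scale

    set_scale(0.5)
    print(attn_processors['down_blocks.0.attn2.processor'].scale)  # 0.5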