diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index df89544e..4755dd88 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -120,7 +120,13 @@ class ImageDataset(Dataset, CaptionMixin): def __getitem__(self, index): img_path = self.file_list[index] - img = exif_transpose(Image.open(img_path)).convert('RGB') + try: + img = exif_transpose(Image.open(img_path)).convert('RGB') + except Exception as e: + print(f"Error opening image: {img_path}") + print(e) + # make a noise image if we can't open it + img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)) # Downscale the source image first img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC) diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 0ab619dd..e30754de 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -271,6 +271,11 @@ class IPAdapter(torch.nn.Module): else: raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") + if not self.config.train_image_encoder: + # compile it + print('Compiling image encoder') + torch.compile(self.image_encoder, fullgraph=True) + self.input_size = self.image_encoder.config.image_size if self.config.quad_image: # 4x4 image