mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fixes for dataloader
This commit is contained in:
@@ -32,7 +32,7 @@ bucket_tolerance = 64
|
|||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
||||||
dataset_config = DatasetConfig(
|
dataset_config = DatasetConfig(
|
||||||
folder_path=dataset_folder,
|
dataset_path=dataset_folder,
|
||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
caption_ext='txt',
|
caption_ext='txt',
|
||||||
default_caption='default',
|
default_caption='default',
|
||||||
@@ -48,22 +48,22 @@ for batch in dataloader:
|
|||||||
batch: 'DataLoaderBatchDTO'
|
batch: 'DataLoaderBatchDTO'
|
||||||
img_batch = batch.tensor
|
img_batch = batch.tensor
|
||||||
|
|
||||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
# chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||||
# put them so they are size by side
|
# # put them so they are size by side
|
||||||
big_img = torch.cat(chunks, dim=3)
|
# big_img = torch.cat(chunks, dim=3)
|
||||||
big_img = big_img.squeeze(0)
|
# big_img = big_img.squeeze(0)
|
||||||
|
#
|
||||||
min_val = big_img.min()
|
# min_val = big_img.min()
|
||||||
max_val = big_img.max()
|
# max_val = big_img.max()
|
||||||
|
#
|
||||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
# big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||||
|
#
|
||||||
# convert to image
|
# # convert to image
|
||||||
img = transforms.ToPILImage()(big_img)
|
# img = transforms.ToPILImage()(big_img)
|
||||||
|
#
|
||||||
show_img(img)
|
# show_img(img)
|
||||||
|
#
|
||||||
time.sleep(1.0)
|
# time.sleep(1.0)
|
||||||
|
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|||||||
@@ -251,8 +251,8 @@ class ImageProcessingDTOMixin:
|
|||||||
transform: Union[None, transforms.Compose]
|
transform: Union[None, transforms.Compose]
|
||||||
):
|
):
|
||||||
# todo make sure this matches
|
# todo make sure this matches
|
||||||
img = Image.open(self.path).convert('RGB')
|
|
||||||
try:
|
try:
|
||||||
|
img = Image.open(self.path).convert('RGB')
|
||||||
img = exif_transpose(img)
|
img = exif_transpose(img)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user