Allow adapter image to be cropped to match bucket cropping

This commit is contained in:
Jaret Burkett
2023-09-25 10:38:54 -06:00
parent 76c764af49
commit c5d49ba661

View File

@@ -68,9 +68,21 @@ class SDTrainer(BaseSDTrainProcess):
adapter_tensors = []
# load images with torch transforms
for idx, adapter_image in enumerate(adapter_images):
img = Image.open(adapter_image)
# resize to match batch shape
img = img.resize((width, height))
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
# to the smallest dimension
img: Image.Image = Image.open(adapter_image)
if img.width > img.height:
# scale down so height is the same as batch
new_height = height
new_width = int(img.width * (height / img.height))
else:
new_width = width
new_height = int(img.height * (width / img.width))
img = img.resize((new_width, new_height))
crop_fn = transforms.CenterCrop((height, width))
# crop the center to match batch
img = crop_fn(img)
img = adapter_transforms(img)
adapter_tensors.append(img)