mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-02 03:49:47 +00:00
Allow adapter image to be cropped to match bucket cropping
This commit is contained in:
@@ -68,9 +68,21 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
adapter_tensors = []
|
||||
# load images with torch transforms
|
||||
for idx, adapter_image in enumerate(adapter_images):
|
||||
img = Image.open(adapter_image)
|
||||
# resize to match batch shape
|
||||
img = img.resize((width, height))
|
||||
# we need to centrally crop the largest dimension of the image to match the batch shape after scaling
|
||||
# to the smallest dimension
|
||||
img: Image.Image = Image.open(adapter_image)
|
||||
if img.width > img.height:
|
||||
# scale down so height is the same as batch
|
||||
new_height = height
|
||||
new_width = int(img.width * (height / img.height))
|
||||
else:
|
||||
new_width = width
|
||||
new_height = int(img.height * (width / img.width))
|
||||
|
||||
img = img.resize((new_width, new_height))
|
||||
crop_fn = transforms.CenterCrop((height, width))
|
||||
# crop the center to match batch
|
||||
img = crop_fn(img)
|
||||
img = adapter_transforms(img)
|
||||
adapter_tensors.append(img)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user