diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index f98cab23..bb767786 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -68,9 +68,21 @@ class SDTrainer(BaseSDTrainProcess): adapter_tensors = [] # load images with torch transforms for idx, adapter_image in enumerate(adapter_images): - img = Image.open(adapter_image) - # resize to match batch shape - img = img.resize((width, height)) + # we need to centrally crop the largest dimension of the image to match the batch shape after scaling + # to the smallest dimension + img: Image.Image = Image.open(adapter_image) + if img.width > img.height: + # scale down so height is the same as batch + new_height = height + new_width = int(img.width * (height / img.height)) + else: + new_width = width + new_height = int(img.height * (width / img.width)) + + img = img.resize((new_width, new_height)) + crop_fn = transforms.CenterCrop((height, width)) + # crop the center to match batch + img = crop_fn(img) img = adapter_transforms(img) adapter_tensors.append(img)