mirror of
https://github.com/Haoming02/sd-webui-old-photo-restoration.git
synced 2026-05-01 03:31:48 +00:00
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:33:44 2023 +0800 add README
commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:12:16 2023 +0800 fix Face Enhancement distortion
commit 6d52de5368c6cfbd9342465b5238725c186e00b9 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 18:27:25 2023 +0800 better? args handling
commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:40:19 2023 +0800 bug fix related to Scratch
commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:24:52 2023 +0800 implement step 2 ~ 4
commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 11:55:20 2023 +0800 process scratch
commit 3b18f7b042 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 11:57:20 2023 +0800 "init"
commit d0148e0e82 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 10:34:39 2023 +0800 clone repo
109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from .base_dataset import BaseDataset, get_params, get_transform
|
|
from PIL import Image
|
|
from ..util import util
|
|
import os
|
|
|
|
|
|
class Pix2pixDataset(BaseDataset):
    """Dataset of paired label maps and real images, with optional instance maps.

    Subclasses supply the file lists by overriding :meth:`get_paths`. This
    class sorts and truncates those lists, optionally checks that each label
    and image share the same file stem, and loads and transforms one sample
    per :meth:`__getitem__` call.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this dataset's command-line options on ``parser``.

        Args:
            parser: An argparse-style parser exposing ``add_argument``.
            is_train: Whether options are built for training (unused here).

        Returns:
            The same ``parser``, extended with ``--no_pairing_check``.
        """
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        return parser

    def initialize(self, opt):
        """Collect, sort, truncate and (optionally) validate the sample paths.

        Args:
            opt: Options namespace. This method reads ``no_instance``,
                ``max_dataset_size`` and ``no_pairing_check``.

        Raises:
            ValueError: If pairing checks are enabled and a label/image pair
                does not share the same file name (extension ignored).
        """
        self.opt = opt

        label_paths, image_paths, instance_paths = self.get_paths(opt)

        # natural_sort's return value is not used, so it is expected to sort
        # in place; sorting all lists the same way lines them up by index.
        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)

        label_paths = label_paths[: opt.max_dataset_size]
        image_paths = image_paths[: opt.max_dataset_size]
        instance_paths = instance_paths[: opt.max_dataset_size]

        if not opt.no_pairing_check:
            # Explicit raise instead of `assert`, which `python -O` strips.
            for path1, path2 in zip(label_paths, image_paths):
                if not self.paths_match(path1, path2):
                    raise ValueError(
                        "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
                        % (path1, path2)
                    )

        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths

        self.dataset_size = len(self.label_paths)

    def get_paths(self, opt):
        """Return ``(label_paths, image_paths, instance_paths)`` for ``opt``.

        Subclasses must override this. Each element is a list of file paths;
        ``instance_paths`` may be empty when ``opt.no_instance`` is set.

        Raises:
            NotImplementedError: Always, in this base implementation. (The
                previous ``assert False`` vanished under ``python -O`` and
                silently produced an empty dataset.)
        """
        raise NotImplementedError(
            "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
        )

    def paths_match(self, path1, path2):
        """Return True if both paths have the same base name, ignoring extension."""
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        """Load and transform the sample at ``index``.

        Returns:
            dict with keys ``"label"``, ``"instance"`` (``0`` when
            ``opt.no_instance`` is set), ``"image"`` and ``"path"`` (the image
            path), after passing through :meth:`postprocess`.

        Raises:
            ValueError: If pairing checks are enabled (``--no_pairing_check``
                not given) and the label/image paths at ``index`` differ.
        """
        # Label map: nearest-neighbour resampling and no normalization,
        # presumably so integer label ids are not blended or rescaled.
        label_path = self.label_paths[index]
        with Image.open(label_path) as label:
            params = get_params(self.opt, label.size)
            transform_label = get_transform(
                self.opt, params, method=Image.NEAREST, normalize=False
            )
            # * 255.0 presumably undoes a [0, 1] scaling from tensor conversion.
            label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc

        # Input image (real images). Honour --no_pairing_check here too;
        # previously this check ran unconditionally and defeated the option.
        image_path = self.image_paths[index]
        if not self.opt.no_pairing_check and not self.paths_match(label_path, image_path):
            raise ValueError(
                "The label_path %s and image_path %s don't match." % (label_path, image_path)
            )
        with Image.open(image_path) as raw_image:
            # convert() returns a new, fully loaded image, so the file can close.
            image = raw_image.convert("RGB")

        # Same random params as the label so both receive identical geometry.
        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # If using instance maps.
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            with Image.open(instance_path) as instance:
                if instance.mode == "L":
                    instance_tensor = transform_label(instance) * 255
                    instance_tensor = instance_tensor.long()
                else:
                    instance_tensor = transform_label(instance)

        input_dict = {
            "label": label_tensor,
            "instance": instance_tensor,
            "image": image_tensor,
            "path": image_path,
        }

        # Give subclasses a chance to modify the final output: they may mutate
        # input_dict in place or return a replacement dict.
        processed = self.postprocess(input_dict)
        if processed is not None:
            input_dict = processed

        return input_dict

    def postprocess(self, input_dict):
        """Hook for subclasses to adjust a sample; the base returns it unchanged."""
        return input_dict

    def __len__(self):
        """Return the number of samples (label paths after truncation)."""
        return self.dataset_size
|