mirror of
https://github.com/Haoming02/sd-webui-old-photo-restoration.git
synced 2026-01-26 19:29:52 +00:00
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:33:44 2023 +0800 add README commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:12:16 2023 +0800 fix Face Enhancement distortion commit 6d52de5368c6cfbd9342465b5238725c186e00b9 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 18:27:25 2023 +0800 better? args handling commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:40:19 2023 +0800 bug fix related to Scratch commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:24:52 2023 +0800 implement step 2 ~ 4 commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 11:55:20 2023 +0800 process scratch commit 3b18f7b042 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 11:57:20 2023 +0800 "init" commit d0148e0e82 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 10:34:39 2023 +0800 clone repo
102 lines
2.7 KiB
Python
102 lines
2.7 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import torch.utils.data as data
|
|
from PIL import Image
|
|
import os
|
|
|
|
# Recognized image file suffixes. Matching is case-sensitive, which is why
# both lower- and upper-case variants appear; NOTE(review): ".tiff" and
# ".webp" have no upper-case counterparts here, so "X.TIFF"/"X.WEBP" are
# NOT recognized — confirm this is intentional before extending.
IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tiff",
    ".webp",
]


def is_image_file(filename):
    """Return True if *filename* ends with one of IMG_EXTENSIONS.

    str.endswith accepts a tuple of suffixes, checking every candidate in
    a single C-level call instead of a Python-level generator loop.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
|
|
|
|
|
def make_dataset_rec(dir, images):
    """Recursively gather image file paths under *dir* into *images*.

    Mutates *images* in place (stdlib-style, returns None). Follows
    symlinked directories; directories are visited in sorted order so
    the result is deterministic.
    """
    assert os.path.isdir(dir), "%s is not a valid directory" % dir

    for root, _dnames, fnames in sorted(os.walk(dir, followlinks=True)):
        images.extend(
            os.path.join(root, fname) for fname in fnames if is_image_file(fname)
        )
|
|
|
|
|
|
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    """Collect image file paths below *dir*.

    Args:
        dir: directory to scan.
        recursive: when True, delegate to make_dataset_rec (which also
            follows symlinks). NOTE(review): the non-recursive branch
            still uses os.walk, so subdirectories are traversed either
            way — the flag effectively controls symlink following and
            the validity check, not recursion depth.
        read_cache: if a "files.list" cache exists inside *dir*, return
            its contents instead of scanning the filesystem.
        write_cache: after scanning, write the collected paths to
            "files.list" inside *dir*.

    Returns:
        List of file paths whose suffix matches IMG_EXTENSIONS.
    """
    images = []

    # Fast path: reuse a previously written file list when asked to.
    if read_cache:
        cached = os.path.join(dir, "files.list")
        if os.path.isfile(cached):
            with open(cached, "r") as f:
                return f.read().splitlines()

    if recursive:
        make_dataset_rec(dir, images)
    else:
        assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir

        for root, _dnames, fnames in sorted(os.walk(dir)):
            images.extend(
                os.path.join(root, fname) for fname in fnames if is_image_file(fname)
            )

    if write_cache:
        filelist_cache = os.path.join(dir, "files.list")
        with open(filelist_cache, "w") as f:
            # One "path\n" record per image, matching the read_cache format.
            f.writelines("%s\n" % path for path in images)
        print("wrote filelist cache at %s" % filelist_cache)

    return images
|
|
|
|
|
|
def default_loader(path):
    """Open the file at *path* with PIL and return it converted to RGB."""
    image = Image.open(path)
    return image.convert("RGB")
|
|
|
|
|
|
class ImageFolder(data.Dataset):
    """Dataset yielding images (optionally paired with paths) from a folder.

    Scans *root* once at construction via make_dataset (default flags, so
    subdirectories are still walked) and lazily loads each file on access
    through *loader*.
    """

    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        """
        Args:
            root: directory to scan for image files.
            transform: optional callable applied to each loaded image.
            return_paths: when True, __getitem__ returns (image, path).
            loader: callable mapping a file path to an image object.

        Raises:
            RuntimeError: if no image file is found under *root*.
        """
        imgs = make_dataset(root)
        if len(imgs) == 0:
            # Plain `raise` — the original wrapped the exception in
            # redundant parentheses (`raise (RuntimeError(...))`).
            raise RuntimeError(
                "Found 0 images in: " + root + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
            )

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load the image at *index*, applying transform if configured."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        return img

    def __len__(self):
        """Number of image paths found at construction time."""
        return len(self.imgs)
|