Files
sd-webui-old-photo-restoration/Face_Enhancement/util/util.py
Haoming 89a8626838 Squashed commit of the following:
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b
Author: Haoming <hmstudy02@gmail.com>
Date:   Tue Dec 19 11:33:44 2023 +0800

    add README

commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c
Author: Haoming <hmstudy02@gmail.com>
Date:   Tue Dec 19 11:12:16 2023 +0800

    fix Face Enhancement distortion

commit 6d52de5368c6cfbd9342465b5238725c186e00b9
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 18:27:25 2023 +0800

    better? args handling

commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 17:40:19 2023 +0800

    bug fix related to Scratch

commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 17:24:52 2023 +0800

    implement step 2 ~ 4

commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 11:55:20 2023 +0800

    process scratch

commit 3b18f7b042
Author: Haoming <hmstudy02@gmail.com>
Date:   Wed Dec 13 11:57:20 2023 +0800

    "init"

commit d0148e0e82
Author: Haoming <hmstudy02@gmail.com>
Date:   Wed Dec 13 10:34:39 2023 +0800

    clone repo
2023-12-19 11:35:38 +08:00

218 lines
6.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
import argparse
import dill as pickle
def save_obj(obj, name):
    """Serialize ``obj`` into the file at path ``name``.

    Uses the module-level ``pickle`` name with its highest protocol.
    """
    with open(name, "wb") as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
def load_obj(name):
    """Deserialize and return the object stored in the file at path ``name``.

    Security: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(name, "rb") as handle:
        restored = pickle.load(handle)
    return restored
def copyconf(default_opt, **kwargs):
    """Return a new ``argparse.Namespace`` copied from ``default_opt`` with overrides.

    The copy is shallow (attribute values are shared), and ``default_opt`` is
    left untouched. Each override is echoed to stdout as ``key value`` before
    it is applied.
    """
    conf = argparse.Namespace(**vars(default_opt))
    for key, value in kwargs.items():
        print(key, value)
        setattr(conf, key, value)
    return conf
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    """Convert an image tensor (or a list of them) into numpy image array(s).

    Args:
        image_tensor: a tensor, or a list of tensors converted one by one.
            A 4-D tensor is treated as a batch and converted image by image,
            then stacked along a new leading axis. A 2-D tensor gets a leading
            axis of size 1. Axis 0 is then moved last, presumably (C, H, W) ->
            (H, W, C).
        imtype: numpy dtype of the returned array(s).
        normalize: if True, values are mapped as ``(x + 1) / 2 * 255`` (i.e.
            [-1, 1] -> [0, 255]); otherwise as ``x * 255``. Results are clipped
            to [0, 255] before the dtype cast.
        tile: unused; kept for interface compatibility.

    Returns:
        A numpy array, or a list of arrays for list input. A trailing
        channel axis of size 1 is dropped.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(tensor, imtype, normalize) for tensor in image_tensor]

    if image_tensor.dim() == 4:
        # Convert each image of the batch separately. imtype/normalize must be
        # forwarded; previously they were dropped, so batched input always
        # used the defaults.
        images_np = []
        for one_image in image_tensor:
            one_image_np = tensor2im(one_image, imtype, normalize)
            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
        return np.concatenate(images_np, axis=0)

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = np.transpose(image_tensor.detach().cpu().float().numpy(), (1, 2, 0))
    if normalize:
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
    else:
        image_numpy = image_numpy * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    """Render a label tensor as a colour numpy image.

    Args:
        label_tensor: a 4-D tensor is treated as a batch and rendered image by
            image, then stacked along a new leading axis. A 1-D tensor yields a
            blank ``(64, 64, 3)`` uint8 array. Otherwise, when axis 0 has more
            than one entry, the argmax over axis 0 is taken as the label map
            before colouring.
        n_label: number of colours passed to ``Colorize``; 0 defers to
            ``tensor2im`` instead.
        imtype: dtype of the returned array (not applied to the 1-D case).
        tile: unused; kept for interface compatibility.
    """
    if label_tensor.dim() == 4:
        per_image = []
        for single in label_tensor:
            rendered = tensor2label(single, n_label, imtype)
            per_image.append(rendered.reshape(1, *rendered.shape))
        return np.concatenate(per_image, axis=0)

    if label_tensor.dim() == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)

    if n_label == 0:
        return tensor2im(label_tensor, imtype)

    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # collapse axis 0 to the index of its maximum (one-hot/scores -> ids)
        labels = labels.max(0, keepdim=True)[1]
    colored = Colorize(n_label)(labels)
    return np.transpose(colored.numpy(), (1, 2, 0)).astype(imtype)
def save_image(image_numpy, image_path, create_dir=False):
    """Write ``image_numpy`` to disk with PIL, saving ``.jpg`` paths as ``.png``.

    Args:
        image_numpy: array accepted by ``PIL.Image.fromarray``. A 2-D array or
            a single-channel 3-D array is replicated to 3 channels first.
        image_path: destination path. If its extension is exactly ``.jpg``, it
            is swapped for ``.png``; any other path is used unchanged.
        create_dir: if True, create the destination's parent directory
            (and missing ancestors) first.
    """
    # Rewrite only the extension. The former str.replace(".jpg", ".png") also
    # rewrote ".jpg" inside directory names, so the file could land in a
    # different directory than the one create_dir had just created.
    root, ext = os.path.splitext(image_path)
    if ext == ".jpg":
        image_path = root + ".png"

    if create_dir:
        parent = os.path.dirname(image_path)
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    Image.fromarray(image_numpy).save(image_path)
def mkdirs(paths):
    """Create each directory in ``paths`` when it is a list, else the single path ``paths``.

    Every path is handled by ``mkdir``.
    """
    if not isinstance(paths, list):
        mkdir(paths)
        return
    for path in paths:
        mkdir(path)
def mkdir(path):
    """Create directory ``path`` and any missing parents; do nothing if it exists.

    ``exist_ok=True`` replaces the previous ``os.path.exists`` check, which
    raced with concurrent creators of the same directory (check, then
    ``makedirs`` raising ``FileExistsError``). An existing non-directory at
    ``path`` now raises ``FileExistsError`` instead of being silently accepted.
    """
    os.makedirs(path, exist_ok=True)
def atoi(text):
    """Return ``int(text)`` if ``text`` is made only of decimal digits, else ``text``.

    ``str.isdecimal`` is used instead of ``str.isdigit``. ``isdigit`` also
    accepts characters such as superscripts ("²") that ``int()`` rejects,
    which made this function raise ``ValueError``.
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """
    Split ``text`` into a key in which runs of digits compare as integers.

    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    # Raw string: "\d" in a plain string literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12). The pattern's
    # runtime value is unchanged.
    return [atoi(c) for c in re.split(r"(\d+)", text)]
def natural_sort(items):
    """Sort the list ``items`` in place in human order (key: ``natural_keys``); returns None."""
    human_order = natural_keys
    items.sort(key=human_order)
def str2bool(v):
    """Parse a yes/no style string, case-insensitively, for use as an argparse ``type``.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not one of the recognised words.
    """
    answer = v.lower()
    if answer in ("yes", "true", "t", "y", "1"):
        return True
    if answer in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def find_class_in_module(target_cls_name, module):
    """Return the object in a networks module whose lowercased name matches.

    Args:
        target_cls_name: name to look up. Underscores are removed and the
            comparison against each module attribute name is case-insensitive.
        module: one of ``'models.networks.generator'`` or
            ``'models.networks.encoder'``.

    Raises:
        NotImplementedError: if ``module`` is not one of the supported names.
        ValueError: if no attribute matches. Previously this printed a message
            and called ``exit(0)``, which terminated the host process while
            reporting success.
    """
    target_cls_name = target_cls_name.replace("_", "").lower()
    if module == 'models.networks.generator':
        from ..models.networks import generator as clslib
    elif module == 'models.networks.encoder':
        from ..models.networks import encoder as clslib
    else:
        raise NotImplementedError(f'Module: {module}')

    cls = None
    for name, clsobj in vars(clslib).items():
        # no break: as before, the last matching entry wins
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        raise ValueError(
            "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
            % (module, target_cls_name)
        )
    return cls
def save_network(net, label, epoch, opt):
    """Save ``net``'s state dict to ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth``.

    The network is moved to the CPU before its state dict is taken, and is
    moved back with ``.cuda()`` afterwards when ``opt.gpu_ids`` is non-empty
    and CUDA is available. The target directory is not created here.
    """
    filename = f"{epoch}_net_{label}.pth"
    target = os.path.join(opt.checkpoints_dir, opt.name, filename)
    cpu_state = net.cpu().state_dict()
    torch.save(cpu_state, target)
    back_to_gpu = len(opt.gpu_ids) > 0 and torch.cuda.is_available()
    if back_to_gpu:
        net.cuda()
def load_network(net, label, epoch, opt):
    """Load ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth`` into ``net`` in place.

    If the checkpoint file does not exist, ``net`` is returned unchanged and a
    warning is emitted. Previously this was silent, so a missing or misplaced
    checkpoint went unnoticed and the network kept its current weights.

    Returns:
        ``net`` (the same object).
    """
    save_filename = "%s_net_%s.pth" % (epoch, label)
    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    save_path = os.path.join(save_dir, save_filename)
    if not os.path.exists(save_path):
        import warnings

        warnings.warn(f"Checkpoint not found, keeping current weights: {save_path}")
        return net
    # map_location="cpu": a checkpoint holding CUDA tensors would otherwise
    # fail to load on a CPU-only machine. load_state_dict copies the values
    # onto whatever device net's parameters already live on.
    weights = torch.load(save_path, map_location="cpu")
    net.load_state_dict(weights)
    return net
###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
    """Return the lowest ``count`` bits of integer ``n`` as a '0'/'1' string, most significant first."""
    bits = ((n >> shift) & 1 for shift in reversed(range(count)))
    return "".join(map(str, bits))
class Colorize(object):
    """Map a label image to a 3-channel uint8 tensor using a colour table.

    NOTE(review): ``labelcolormap`` is not defined in the visible part of this
    module. Confirm it is provided elsewhere; otherwise constructing a
    ``Colorize`` raises ``NameError``.
    """

    def __init__(self, n=35):
        # keep the first n colours of the table, as a torch tensor
        self.cmap = torch.from_numpy(labelcolormap(n)[:n])

    def __call__(self, gray_image):
        """Colour the label map ``gray_image[0]``.

        Returns a CPU uint8 tensor of shape ``(3, size[1], size[2])``, where
        ``size = gray_image.size()``. Pixels whose label has no entry in the
        table stay 0 (black).
        """
        dims = gray_image.size()
        color_image = torch.zeros(3, dims[1], dims[2], dtype=torch.uint8, device="cpu")
        labels = gray_image[0]
        for label, color in enumerate(self.cmap):
            mask = (label == labels).cpu()
            for channel in range(3):
                color_image[channel][mask] = color[channel]
        return color_image