submodule

This commit is contained in:
Haoming
2024-11-04 15:53:34 +08:00
parent 45cd9fb513
commit b8c6cf1a95
76 changed files with 54 additions and 9381 deletions

.gitignore vendored (7 lines changed)

@@ -1,8 +1 @@
# Junks
__pycache__
*.pyc
*~
# Models
*landmarks.dat
*/checkpoints/*

.gitmodules vendored (new file, 3 lines added)

@@ -0,0 +1,3 @@
[submodule "lib_bopb2l"]
path = lib_bopb2l
url = https://github.com/Haoming02/BOP-B2L-Backend
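With the backend now tracked as the lib_bopb2l submodule, a fresh clone of the extension has to fetch the submodule before its code is usable. Below is a minimal Python sketch of how an install step could ensure that; the helper name and the repo-root resolution are assumptions for illustration, not part of this commit.

import os
import subprocess

def ensure_bopb2l_submodule(repo_root: str = os.path.dirname(os.path.abspath(__file__))) -> None:
    # Assumed layout: this helper lives at the repository root, next to .gitmodules.
    gitlink = os.path.join(repo_root, "lib_bopb2l", ".git")
    if not os.path.exists(gitlink):  # submodule not initialized (or not yet cloned)
        subprocess.check_call(
            ["git", "submodule", "update", "--init", "--recursive"],
            cwd=repo_root,
        )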


@@ -1,430 +0,0 @@
# Copyright (c) Microsoft Corporation
from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from PIL import Image, ImageFilter
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
import numpy as np
import dlib
import cv2
import os
def calculate_cdf(histogram):
"""
This method calculates the cumulative distribution function
:param array histogram: The values of the histogram
:return: normalized_cdf: The normalized cumulative distribution function
:rtype: array
"""
# Get the cumulative sum of the elements
cdf = histogram.cumsum()
# Normalize the cdf
normalized_cdf = cdf / float(cdf.max())
return normalized_cdf
def calculate_lookup(src_cdf, ref_cdf):
"""
This method creates the lookup table
:param array src_cdf: The cdf for the source image
:param array ref_cdf: The cdf for the reference image
:return: lookup_table: The lookup table
:rtype: array
"""
lookup_table = np.zeros(256)
lookup_val = 0
for src_pixel_val in range(len(src_cdf)):
for ref_pixel_val in range(len(ref_cdf)):
if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
lookup_val = ref_pixel_val
break
lookup_table[src_pixel_val] = lookup_val
return lookup_table
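# Illustrative note (not part of the original file): for each source intensity, the inner
# loop above picks the first reference intensity whose CDF value reaches the source CDF
# value; this is the standard histogram-specification rule that match_histograms() uses below.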
def match_histograms(src_image, ref_image):
"""
This method matches the source image histogram to that of the
reference image
:param image src_image: The original source image
:param image ref_image: The reference image
:return: image_after_matching
:rtype: image (array)
"""
# Split the images into the different color channels
# b means blue, g means green and r means red
src_b, src_g, src_r = cv2.split(src_image)
ref_b, ref_g, ref_r = cv2.split(ref_image)
# Compute the b, g, and r histograms separately
# The NumPy flatten() method returns a copy of the array
# collapsed into one dimension.
src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
# Compute the normalized cdf for the source and reference image
src_cdf_blue = calculate_cdf(src_hist_blue)
src_cdf_green = calculate_cdf(src_hist_green)
src_cdf_red = calculate_cdf(src_hist_red)
ref_cdf_blue = calculate_cdf(ref_hist_blue)
ref_cdf_green = calculate_cdf(ref_hist_green)
ref_cdf_red = calculate_cdf(ref_hist_red)
# Make a separate lookup table for each color
blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
# Use the lookup function to transform the colors of the original
# source image
blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
green_after_transform = cv2.LUT(src_g, green_lookup_table)
red_after_transform = cv2.LUT(src_r, red_lookup_table)
# Put the image back together
image_after_matching = cv2.merge(
[blue_after_transform, green_after_transform, red_after_transform]
)
image_after_matching = cv2.convertScaleAbs(image_after_matching)
return image_after_matching
def _standard_face_pts():
pts = (
np.array(
[196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
np.float32,
)
/ 256.0
- 1.0
)
return np.reshape(pts, (5, 2))
def _origin_face_pts():
pts = np.array(
[196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
np.float32,
)
return np.reshape(pts, (5, 2))
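# Illustrative note (not part of the original file): the ten numbers above are (x, y) pairs
# for five canonical landmarks (left eye, right eye, nose tip, left and right mouth corner)
# on a 512-wide template face; _standard_face_pts() rescales them into [-1, 1] via / 256.0 - 1.0.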
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
std_pts = _standard_face_pts() # [-1,1]
target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
# print(target_pts)
h, w, c = img.shape
if normalize:
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
# print(landmark)
affine = SimilarityTransform()
affine.estimate(target_pts, landmark)
return affine
def compute_inverse_transformation_matrix(
img, landmark, normalize, target_face_scale=1.0
):
std_pts = _standard_face_pts() # [-1,1]
target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
# print(target_pts)
h, w, c = img.shape
if normalize:
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
# print(landmark)
affine = SimilarityTransform()
affine.estimate(landmark, target_pts)
return affine
def show_detection(image, box, landmark):
plt.imshow(image)
print(box[2] - box[0])
plt.gca().add_patch(
Rectangle(
(box[1], box[0]),
box[2] - box[0],
box[3] - box[1],
linewidth=1,
edgecolor="r",
facecolor="none",
)
)
plt.scatter(landmark[0][0], landmark[0][1])
plt.scatter(landmark[1][0], landmark[1][1])
plt.scatter(landmark[2][0], landmark[2][1])
plt.scatter(landmark[3][0], landmark[3][1])
plt.scatter(landmark[4][0], landmark[4][1])
plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
# param = np.linalg.inv(affine)
param = affine
theta = np.zeros([2, 3])
theta[0, 0] = param[0, 0] * input_h / target_h
theta[0, 1] = param[0, 1] * input_w / target_h
theta[0, 2] = (
2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
) / target_h - 1
theta[1, 0] = param[1, 0] * input_h / target_w
theta[1, 1] = param[1, 1] * input_w / target_w
theta[1, 2] = (
2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
) / target_w - 1
return theta
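# Illustrative note (an inference from the formula, not shown in this diff): affine2theta()
# converts a 2x3 pixel-space affine matrix into the normalized "theta" parameterization used
# by grid-sampling frameworks such as torch.nn.functional.affine_grid, where both axes run
# from -1 to 1.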
def blur_blending(im1, im2, mask):
mask = mask * 255.0
kernel = np.ones((10, 10), np.uint8)
mask = cv2.erode(mask, kernel, iterations=1)
mask = Image.fromarray(mask.astype("uint8")).convert("L")
im1 = Image.fromarray(im1.astype("uint8"))
im2 = Image.fromarray(im2.astype("uint8"))
mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
im = Image.composite(im1, im2, mask)
im = Image.composite(im, im2, mask_blur)
return np.array(im) / 255.0
def blur_blending_cv2(im1, im2, mask):
mask = mask * 255.0
kernel = np.ones((9, 9), np.uint8)
mask = cv2.erode(mask, kernel, iterations=3)
mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
mask_blur /= 255.0
im = im1 * mask_blur + (1 - mask_blur) * im2
im /= 255.0
im = np.clip(im, 0.0, 1.0)
return im
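# Illustrative note (not part of the original file): blur_blending_cv2() expects im1/im2 in
# the 0-255 range and a binary 0/1 float mask; the eroded, Gaussian-blurred mask acts as a
# soft alpha and the blended result is returned scaled down to [0, 1].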
def Poisson_blending(im1, im2, mask):
# mask = 1 - mask
mask = mask * 255.0
kernel = np.ones((10, 10), np.uint8)
mask = cv2.erode(mask, kernel, iterations=1)
mask /= 255
mask = 1 - mask
mask *= 255
mask = mask[:, :, 0]
height, width, channels = im1.shape  # numpy shape order is (rows, cols, channels)
center = (int(width / 2), int(height / 2))  # (x, y) center expected by cv2.seamlessClone
result = cv2.seamlessClone(
im2.astype("uint8"),
im1.astype("uint8"),
mask.astype("uint8"),
center,
cv2.MIXED_CLONE,
)
return result / 255.0
def Poisson_B(im1, im2, mask, center):
mask = mask * 255.0
result = cv2.seamlessClone(
im2.astype("uint8"),
im1.astype("uint8"),
mask.astype("uint8"),
center,
cv2.NORMAL_CLONE,
)
return result / 255
def seamless_clone(old_face, new_face, raw_mask):
height, width, _ = old_face.shape
height = height // 2
width = width // 2
y_indices, x_indices, _ = np.nonzero(raw_mask)
y_crop = slice(np.min(y_indices), np.max(y_indices))
x_crop = slice(np.min(x_indices), np.max(x_indices))
y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
insertion_mask[insertion_mask != 0] = 255
prior = np.rint(
np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")
).astype("uint8")
# if np.sum(insertion_mask) == 0:
n_mask = insertion_mask[1:-1, 1:-1, :]
n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
print(n_mask.shape)
x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
if w < 4 or h < 4:
blended = prior
else:
blended = cv2.seamlessClone(
insertion, # pylint: disable=no-member
prior,
insertion_mask,
(x_center, y_center),
cv2.NORMAL_CLONE,
) # pylint: disable=no-member
blended = blended[height:-height, width:-width]
return blended.astype("float32") / 255.0
def get_landmark(face_landmarks, id):
part = face_landmarks.part(id)
x = part.x
y = part.y
return (x, y)
def search(face_landmarks):
x1, y1 = get_landmark(face_landmarks, 36)
x2, y2 = get_landmark(face_landmarks, 39)
x3, y3 = get_landmark(face_landmarks, 42)
x4, y4 = get_landmark(face_landmarks, 45)
x_nose, y_nose = get_landmark(face_landmarks, 30)
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
x_left_eye = int((x1 + x2) / 2)
y_left_eye = int((y1 + y2) / 2)
x_right_eye = int((x3 + x4) / 2)
y_right_eye = int((y3 + y4) / 2)
results = np.array(
[
[x_left_eye, y_left_eye],
[x_right_eye, y_right_eye],
[x_nose, y_nose],
[x_left_mouth, y_left_mouth],
[x_right_mouth, y_right_mouth],
]
)
return results
def align_warp(
original_image: Image,
restored_faces: list,
) -> Image:
if len(restored_faces) == 0:
return original_image
face_detector = dlib.get_frontal_face_detector()
landmark = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"shape_predictor_68_face_landmarks.dat",
)
landmark_locator = dlib.shape_predictor(landmark)
origin_width, origin_height = original_image.size
image = np.array(original_image)
faces = face_detector(image)
blended = image
for face_id in range(len(faces)):
current_face = faces[face_id]
face_landmarks = landmark_locator(image, current_face)
current_fl = search(face_landmarks)
forward_mask = np.ones_like(image).astype("uint8")
affine = compute_transformation_matrix(
image, current_fl, False, target_face_scale=1.3
)
aligned_face = warp(
image, affine, output_shape=(256, 256, 3), preserve_range=True
)
forward_mask = warp(
forward_mask,
affine,
output_shape=(256, 256, 3),
order=0,
preserve_range=True,
)
affine_inverse = affine.inverse
cur_face = np.array(restored_faces[face_id])
## Histogram Color matching
A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
B = match_histograms(B, A)
cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)
warped_back = warp(
cur_face,
affine_inverse,
output_shape=(origin_height, origin_width, 3),
order=3,
preserve_range=True,
)
backward_mask = warp(
forward_mask,
affine_inverse,
output_shape=(origin_height, origin_width, 3),
order=0,
preserve_range=True,
) ## Nearest neighbour
blended = blur_blending_cv2(warped_back, blended, backward_mask)
blended *= 255.0
return Image.fromarray(img_as_ubyte(blended / 255.0))


@@ -1,430 +0,0 @@
# Copyright (c) Microsoft Corporation
from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from PIL import Image, ImageFilter
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
import numpy as np
import dlib
import cv2
import os
def calculate_cdf(histogram):
"""
This method calculates the cumulative distribution function
:param array histogram: The values of the histogram
:return: normalized_cdf: The normalized cumulative distribution function
:rtype: array
"""
# Get the cumulative sum of the elements
cdf = histogram.cumsum()
# Normalize the cdf
normalized_cdf = cdf / float(cdf.max())
return normalized_cdf
def calculate_lookup(src_cdf, ref_cdf):
"""
This method creates the lookup table
:param array src_cdf: The cdf for the source image
:param array ref_cdf: The cdf for the reference image
:return: lookup_table: The lookup table
:rtype: array
"""
lookup_table = np.zeros(256)
lookup_val = 0
for src_pixel_val in range(len(src_cdf)):
for ref_pixel_val in range(len(ref_cdf)):
if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
lookup_val = ref_pixel_val
break
lookup_table[src_pixel_val] = lookup_val
return lookup_table
def match_histograms(src_image, ref_image):
"""
This method matches the source image histogram to that of the
reference image
:param image src_image: The original source image
:param image ref_image: The reference image
:return: image_after_matching
:rtype: image (array)
"""
# Split the images into the different color channels
# b means blue, g means green and r means red
src_b, src_g, src_r = cv2.split(src_image)
ref_b, ref_g, ref_r = cv2.split(ref_image)
# Compute the b, g, and r histograms separately
# The NumPy flatten() method returns a copy of the array
# collapsed into one dimension.
src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
# Compute the normalized cdf for the source and reference image
src_cdf_blue = calculate_cdf(src_hist_blue)
src_cdf_green = calculate_cdf(src_hist_green)
src_cdf_red = calculate_cdf(src_hist_red)
ref_cdf_blue = calculate_cdf(ref_hist_blue)
ref_cdf_green = calculate_cdf(ref_hist_green)
ref_cdf_red = calculate_cdf(ref_hist_red)
# Make a separate lookup table for each color
blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
# Use the lookup function to transform the colors of the original
# source image
blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
green_after_transform = cv2.LUT(src_g, green_lookup_table)
red_after_transform = cv2.LUT(src_r, red_lookup_table)
# Put the image back together
image_after_matching = cv2.merge(
[blue_after_transform, green_after_transform, red_after_transform]
)
image_after_matching = cv2.convertScaleAbs(image_after_matching)
return image_after_matching
def _standard_face_pts():
pts = (
np.array(
[196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
np.float32,
)
/ 256.0
- 1.0
)
return np.reshape(pts, (5, 2))
def _origin_face_pts():
pts = np.array(
[196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
np.float32,
)
return np.reshape(pts, (5, 2))
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
std_pts = _standard_face_pts() # [-1,1]
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
# print(target_pts)
h, w, c = img.shape
if normalize:
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
# print(landmark)
affine = SimilarityTransform()
affine.estimate(target_pts, landmark)
return affine
def compute_inverse_transformation_matrix(
img, landmark, normalize, target_face_scale=1.0
):
std_pts = _standard_face_pts() # [-1,1]
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
# print(target_pts)
h, w, c = img.shape
if normalize:
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
# print(landmark)
affine = SimilarityTransform()
affine.estimate(landmark, target_pts)
return affine
def show_detection(image, box, landmark):
plt.imshow(image)
print(box[2] - box[0])
plt.gca().add_patch(
Rectangle(
(box[1], box[0]),
box[2] - box[0],
box[3] - box[1],
linewidth=1,
edgecolor="r",
facecolor="none",
)
)
plt.scatter(landmark[0][0], landmark[0][1])
plt.scatter(landmark[1][0], landmark[1][1])
plt.scatter(landmark[2][0], landmark[2][1])
plt.scatter(landmark[3][0], landmark[3][1])
plt.scatter(landmark[4][0], landmark[4][1])
plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
# param = np.linalg.inv(affine)
param = affine
theta = np.zeros([2, 3])
theta[0, 0] = param[0, 0] * input_h / target_h
theta[0, 1] = param[0, 1] * input_w / target_h
theta[0, 2] = (
2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
) / target_h - 1
theta[1, 0] = param[1, 0] * input_h / target_w
theta[1, 1] = param[1, 1] * input_w / target_w
theta[1, 2] = (
2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
) / target_w - 1
return theta
def blur_blending(im1, im2, mask):
mask = mask * 255.0
kernel = np.ones((10, 10), np.uint8)
mask = cv2.erode(mask, kernel, iterations=1)
mask = Image.fromarray(mask.astype("uint8")).convert("L")
im1 = Image.fromarray(im1.astype("uint8"))
im2 = Image.fromarray(im2.astype("uint8"))
mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
im = Image.composite(im1, im2, mask)
im = Image.composite(im, im2, mask_blur)
return np.array(im) / 255.0
def blur_blending_cv2(im1, im2, mask):
mask = mask * 255.0
kernel = np.ones((9, 9), np.uint8)
mask = cv2.erode(mask, kernel, iterations=3)
mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
mask_blur /= 255.0
im = im1 * mask_blur + (1 - mask_blur) * im2
im /= 255.0
im = np.clip(im, 0.0, 1.0)
return im
def Poisson_blending(im1, im2, mask):
# mask = 1 - mask
mask = mask * 255.0
kernel = np.ones((10, 10), np.uint8)
mask = cv2.erode(mask, kernel, iterations=1)
mask /= 255
mask = 1 - mask
mask *= 255
mask = mask[:, :, 0]
height, width, channels = im1.shape  # numpy shape order is (rows, cols, channels)
center = (int(width / 2), int(height / 2))  # (x, y) center expected by cv2.seamlessClone
result = cv2.seamlessClone(
im2.astype("uint8"),
im1.astype("uint8"),
mask.astype("uint8"),
center,
cv2.MIXED_CLONE,
)
return result / 255.0
def Poisson_B(im1, im2, mask, center):
mask = mask * 255.0
result = cv2.seamlessClone(
im2.astype("uint8"),
im1.astype("uint8"),
mask.astype("uint8"),
center,
cv2.NORMAL_CLONE,
)
return result / 255
def seamless_clone(old_face, new_face, raw_mask):
height, width, _ = old_face.shape
height = height // 2
width = width // 2
y_indices, x_indices, _ = np.nonzero(raw_mask)
y_crop = slice(np.min(y_indices), np.max(y_indices))
x_crop = slice(np.min(x_indices), np.max(x_indices))
y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))
insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
insertion_mask[insertion_mask != 0] = 255
prior = np.rint(
np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")
).astype("uint8")
# if np.sum(insertion_mask) == 0:
n_mask = insertion_mask[1:-1, 1:-1, :]
n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
print(n_mask.shape)
x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
if w < 4 or h < 4:
blended = prior
else:
blended = cv2.seamlessClone(
insertion, # pylint: disable=no-member
prior,
insertion_mask,
(x_center, y_center),
cv2.NORMAL_CLONE,
) # pylint: disable=no-member
blended = blended[height:-height, width:-width]
return blended.astype("float32") / 255.0
def get_landmark(face_landmarks, id):
part = face_landmarks.part(id)
x = part.x
y = part.y
return (x, y)
def search(face_landmarks):
x1, y1 = get_landmark(face_landmarks, 36)
x2, y2 = get_landmark(face_landmarks, 39)
x3, y3 = get_landmark(face_landmarks, 42)
x4, y4 = get_landmark(face_landmarks, 45)
x_nose, y_nose = get_landmark(face_landmarks, 30)
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
x_left_eye = int((x1 + x2) / 2)
y_left_eye = int((y1 + y2) / 2)
x_right_eye = int((x3 + x4) / 2)
y_right_eye = int((y3 + y4) / 2)
results = np.array(
[
[x_left_eye, y_left_eye],
[x_right_eye, y_right_eye],
[x_nose, y_nose],
[x_left_mouth, y_left_mouth],
[x_right_mouth, y_right_mouth],
]
)
return results
def align_warp_hr(
original_image: Image,
restored_faces: list,
) -> Image:
if len(restored_faces) == 0:
return original_image
face_detector = dlib.get_frontal_face_detector()
landmark = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"shape_predictor_68_face_landmarks.dat",
)
landmark_locator = dlib.shape_predictor(landmark)
origin_width, origin_height = original_image.size
image = np.array(original_image)
faces = face_detector(image)
blended = image
for face_id in range(len(faces)):
current_face = faces[face_id]
face_landmarks = landmark_locator(image, current_face)
current_fl = search(face_landmarks)
forward_mask = np.ones_like(image).astype("uint8")
affine = compute_transformation_matrix(
image, current_fl, False, target_face_scale=1.3
)
aligned_face = warp(
image, affine, output_shape=(512, 512, 3), preserve_range=True
)
forward_mask = warp(
forward_mask,
affine,
output_shape=(512, 512, 3),
order=0,
preserve_range=True,
)
affine_inverse = affine.inverse
cur_face = np.array(restored_faces[face_id])
## Histogram Color matching
A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
B = match_histograms(B, A)
cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)
warped_back = warp(
cur_face,
affine_inverse,
output_shape=(origin_height, origin_width, 3),
order=3,
preserve_range=True,
)
backward_mask = warp(
forward_mask,
affine_inverse,
output_shape=(origin_height, origin_width, 3),
order=0,
preserve_range=True,
) ## Nearest neighbour
blended = blur_blending_cv2(warped_back, blended, backward_mask)
blended *= 255.0
return Image.fromarray(img_as_ubyte(blended / 255.0))


@@ -1,153 +0,0 @@
# Copyright (c) Microsoft Corporation
from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import dlib
import os
def _standard_face_pts():
pts = (
np.array(
[196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
np.float32,
)
/ 256.0
- 1.0
)
return np.reshape(pts, (5, 2))
def get_landmark(face_landmarks, id):
part = face_landmarks.part(id)
x = part.x
y = part.y
return (x, y)
def search(face_landmarks):
x1, y1 = get_landmark(face_landmarks, 36)
x2, y2 = get_landmark(face_landmarks, 39)
x3, y3 = get_landmark(face_landmarks, 42)
x4, y4 = get_landmark(face_landmarks, 45)
x_nose, y_nose = get_landmark(face_landmarks, 30)
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
x_left_eye = int((x1 + x2) / 2)
y_left_eye = int((y1 + y2) / 2)
x_right_eye = int((x3 + x4) / 2)
y_right_eye = int((y3 + y4) / 2)
results = np.array(
[
[x_left_eye, y_left_eye],
[x_right_eye, y_right_eye],
[x_nose, y_nose],
[x_left_mouth, y_left_mouth],
[x_right_mouth, y_right_mouth],
]
)
return results
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
std_pts = _standard_face_pts() # [-1,1]
target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0
# print(target_pts)
h, w, c = img.shape
if normalize:
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
# print(landmark)
affine = SimilarityTransform()
affine.estimate(target_pts, landmark)
return affine.params
def show_detection(image, box, landmark):
plt.imshow(image)
print(box[2] - box[0])
plt.gca().add_patch(
Rectangle(
(box[1], box[0]),
box[2] - box[0],
box[3] - box[1],
linewidth=1,
edgecolor="r",
facecolor="none",
)
)
plt.scatter(landmark[0][0], landmark[0][1])
plt.scatter(landmark[1][0], landmark[1][1])
plt.scatter(landmark[2][0], landmark[2][1])
plt.scatter(landmark[3][0], landmark[3][1])
plt.scatter(landmark[4][0], landmark[4][1])
plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
# param = np.linalg.inv(affine)
param = affine
theta = np.zeros([2, 3])
theta[0, 0] = param[0, 0] * input_h / target_h
theta[0, 1] = param[0, 1] * input_w / target_h
theta[0, 2] = (
2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
) / target_h - 1
theta[1, 0] = param[1, 0] * input_h / target_w
theta[1, 1] = param[1, 1] * input_w / target_w
theta[1, 2] = (
2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
) / target_w - 1
return theta
def detect(input_image: Image) -> list:
face_detector = dlib.get_frontal_face_detector()
detected_faces = []
landmark = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"shape_predictor_68_face_landmarks.dat",
)
landmark_locator = dlib.shape_predictor(landmark)
image = np.array(input_image)
faces = face_detector(image)
if len(faces) == 0:
return detected_faces
for face_id in range(len(faces)):
current_face = faces[face_id]
face_landmarks = landmark_locator(image, current_face)
current_fl = search(face_landmarks)
affine = compute_transformation_matrix(
image, current_fl, False, target_face_scale=1.3
)
aligned_face = warp(image, affine, output_shape=(256, 256, 3))
detected_faces.append(Image.fromarray(img_as_ubyte(aligned_face)))
return detected_faces
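For orientation, a rough sketch of how detect() above and align_warp() from the earlier module appear intended to be chained; restore_face is a hypothetical stand-in for the face-enhancement model, and imports of the two deleted modules are omitted.

from PIL import Image

def restore_photo(original_image: Image.Image, restore_face) -> Image.Image:
    aligned_faces = detect(original_image)               # 256x256 aligned face crops
    restored_faces = [restore_face(face) for face in aligned_faces]
    return align_warp(original_image, restored_faces)    # warp back and blend into the photo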


@@ -1,153 +0,0 @@
# Copyright (c) Microsoft Corporation
from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import dlib
import os
def _standard_face_pts():
pts = (
np.array(
[196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
np.float32,
)
/ 256.0
- 1.0
)
return np.reshape(pts, (5, 2))
def get_landmark(face_landmarks, id):
part = face_landmarks.part(id)
x = part.x
y = part.y
return (x, y)
def search(face_landmarks):
x1, y1 = get_landmark(face_landmarks, 36)
x2, y2 = get_landmark(face_landmarks, 39)
x3, y3 = get_landmark(face_landmarks, 42)
x4, y4 = get_landmark(face_landmarks, 45)
x_nose, y_nose = get_landmark(face_landmarks, 30)
x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)
x_left_eye = int((x1 + x2) / 2)
y_left_eye = int((y1 + y2) / 2)
x_right_eye = int((x3 + x4) / 2)
y_right_eye = int((y3 + y4) / 2)
results = np.array(
[
[x_left_eye, y_left_eye],
[x_right_eye, y_right_eye],
[x_nose, y_nose],
[x_left_mouth, y_left_mouth],
[x_right_mouth, y_right_mouth],
]
)
return results
def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
std_pts = _standard_face_pts() # [-1,1]
target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0
# print(target_pts)
h, w, c = img.shape
if normalize:
landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0
# print(landmark)
affine = SimilarityTransform()
affine.estimate(target_pts, landmark)
return affine.params
def show_detection(image, box, landmark):
plt.imshow(image)
print(box[2] - box[0])
plt.gca().add_patch(
Rectangle(
(box[1], box[0]),
box[2] - box[0],
box[3] - box[1],
linewidth=1,
edgecolor="r",
facecolor="none",
)
)
plt.scatter(landmark[0][0], landmark[0][1])
plt.scatter(landmark[1][0], landmark[1][1])
plt.scatter(landmark[2][0], landmark[2][1])
plt.scatter(landmark[3][0], landmark[3][1])
plt.scatter(landmark[4][0], landmark[4][1])
plt.show()
def affine2theta(affine, input_w, input_h, target_w, target_h):
# param = np.linalg.inv(affine)
param = affine
theta = np.zeros([2, 3])
theta[0, 0] = param[0, 0] * input_h / target_h
theta[0, 1] = param[0, 1] * input_w / target_h
theta[0, 2] = (
2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
) / target_h - 1
theta[1, 0] = param[1, 0] * input_h / target_w
theta[1, 1] = param[1, 1] * input_w / target_w
theta[1, 2] = (
2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
) / target_w - 1
return theta
def detect_hr(input_image: Image) -> list:
face_detector = dlib.get_frontal_face_detector()
detected_faces = []
landmark = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"shape_predictor_68_face_landmarks.dat",
)
landmark_locator = dlib.shape_predictor(landmark)
image = np.array(input_image)
faces = face_detector(image)
if len(faces) == 0:
return detected_faces
for face_id in range(len(faces)):
current_face = faces[face_id]
face_landmarks = landmark_locator(image, current_face)
current_fl = search(face_landmarks)
affine = compute_transformation_matrix(
image, current_fl, False, target_face_scale=1.3
)
aligned_face = warp(image, affine, output_shape=(512, 512, 3))
detected_faces.append(Image.fromarray(img_as_ubyte(aligned_face)))
return detected_faces


@@ -1,24 +0,0 @@
# Copyright (c) Microsoft Corporation
import torch.utils.data
from .face_dataset import FaceTestDataset
def create_dataloader(opt, faces: list):
instance = FaceTestDataset()
instance.initialize(opt, faces)
print(
f"Dataset [{type(instance).__name__}] with {len(instance)} faces was created..."
)
dataloader = torch.utils.data.DataLoader(
instance,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads),
drop_last=opt.isTrain,
)
return dataloader
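A sketch of how this dataloader would be consumed; the option fields mirror the ones referenced above, and the model call is a hypothetical placeholder because the pix2pix model implementation is not part of this diff.

dataloader = create_dataloader(opt, faces)   # faces: list of aligned PIL face crops
for data_i in dataloader:
    # data_i is the dict built by FaceTestDataset.__getitem__ ("label", "image", "path")
    restored = model(data_i, mode="inference")   # hypothetical model API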


@@ -1,125 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
@staticmethod
def modify_commandline_options(parser, is_train):
return parser
def initialize(self, opt):
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess_mode == "resize_and_crop":
new_h = new_w = opt.load_size
elif opt.preprocess_mode == "scale_width_and_crop":
new_w = opt.load_size
new_h = opt.load_size * h // w
elif opt.preprocess_mode == "scale_shortside_and_crop":
ss, ls = min(w, h), max(w, h) # shortside and longside
width_is_shorter = w == ss
ls = int(opt.load_size * ls / ss)
new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
flip = random.random() > 0.5
return {"crop_pos": (x, y), "flip": flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
transform_list = []
if "resize" in opt.preprocess_mode:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, interpolation=method))
elif "scale_width" in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
elif "scale_shortside" in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
if "crop" in opt.preprocess_mode:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))
if opt.preprocess_mode == "none":
base = 32
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.preprocess_mode == "fixed":
w = opt.crop_size
h = round(opt.crop_size / opt.aspect_ratio)
transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))
if toTensor:
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __resize(img, w, h, method=Image.BICUBIC):
return img.resize((w, h), method)
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if ow == target_width:
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
def __scale_shortside(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
width_is_shorter = ow == ss
if ss == target_width:
return img
ls = int(target_width * ls / ss)
nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
return img.resize((nw, nh), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
return img.crop((x1, y1, x1 + tw, y1 + th))
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img


@@ -1,56 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .pix2pix_dataset import Pix2pixDataset
from .image_folder import make_dataset
class CustomDataset(Pix2pixDataset):
""" Dataset that loads images from directories
Use option --label_dir, --image_dir, --instance_dir to specify the directories.
The images in the directories are sorted in alphabetical order and paired in order.
"""
@staticmethod
def modify_commandline_options(parser, is_train):
parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
parser.set_defaults(preprocess_mode="resize_and_crop")
load_size = 286 if is_train else 256
parser.set_defaults(load_size=load_size)
parser.set_defaults(crop_size=256)
parser.set_defaults(display_winsize=256)
parser.set_defaults(label_nc=13)
parser.set_defaults(contain_dontcare_label=False)
parser.add_argument(
"--label_dir", type=str, required=True, help="path to the directory that contains label images"
)
parser.add_argument(
"--image_dir", type=str, required=True, help="path to the directory that contains photo images"
)
parser.add_argument(
"--instance_dir",
type=str,
default="",
help="path to the directory that contains instance maps. Leave black if not exists",
)
return parser
def get_paths(self, opt):
label_dir = opt.label_dir
label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
image_dir = opt.image_dir
image_paths = make_dataset(image_dir, recursive=False, read_cache=True)
if len(opt.instance_dir) > 0:
instance_dir = opt.instance_dir
instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
else:
instance_paths = []
assert len(label_paths) == len(image_paths), (
"The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir)
)
return label_paths, image_paths, instance_paths


@@ -1,72 +0,0 @@
# Copyright (c) Microsoft Corporation
from .base_dataset import BaseDataset, get_params, get_transform
import torch
class FaceTestDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument(
"--no_pairing_check",
action="store_true",
help="If specified, skip sanity check of correct label-image file pairing",
)
# parser.set_defaults(contain_dontcare_label=False)
# parser.set_defaults(no_instance=True)
return parser
def initialize(self, opt, faces: list):
self.opt = opt
self.images = faces # All the images
self.parts = [
"skin",
"hair",
"l_brow",
"r_brow",
"l_eye",
"r_eye",
"eye_g",
"l_ear",
"r_ear",
"ear_r",
"nose",
"mouth",
"u_lip",
"l_lip",
"neck",
"neck_l",
"cloth",
"hat",
]
size = len(self.images)
self.dataset_size = size
def __getitem__(self, index):
params = get_params(self.opt, (-1, -1))
image = self.images[index].convert("RGB")
transform_image = get_transform(self.opt, params)
image_tensor = transform_image(image)
full_label = []
for _ in self.parts:
current_part = torch.zeros((self.opt.load_size, self.opt.load_size))
full_label.append(current_part)
full_label_tensor = torch.stack(full_label, 0)
input_dict = {
"label": full_label_tensor,
"image": image_tensor,
"path": "",
}
return input_dict
def __len__(self):
return self.dataset_size


@@ -1,101 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import os
IMG_EXTENSIONS = [
".jpg",
".JPG",
".jpeg",
".JPEG",
".png",
".PNG",
".ppm",
".PPM",
".bmp",
".BMP",
".tiff",
".webp",
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset_rec(dir, images):
assert os.path.isdir(dir), "%s is not a valid directory" % dir
for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
images = []
if read_cache:
possible_filelist = os.path.join(dir, "files.list")
if os.path.isfile(possible_filelist):
with open(possible_filelist, "r") as f:
images = f.read().splitlines()
return images
if recursive:
make_dataset_rec(dir, images)
else:
assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir
for root, dnames, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
if write_cache:
filelist_cache = os.path.join(dir, "files.list")
with open(filelist_cache, "w") as f:
for path in images:
f.write("%s\n" % path)
print("wrote filelist cache at %s" % filelist_cache)
return images
def default_loader(path):
return Image.open(path).convert("RGB")
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise (
RuntimeError(
"Found 0 images in: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
)
)
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)


@@ -1,108 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
from ..util import util
import os
class Pix2pixDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument(
"--no_pairing_check",
action="store_true",
help="If specified, skip sanity check of correct label-image file pairing",
)
return parser
def initialize(self, opt):
self.opt = opt
label_paths, image_paths, instance_paths = self.get_paths(opt)
util.natural_sort(label_paths)
util.natural_sort(image_paths)
if not opt.no_instance:
util.natural_sort(instance_paths)
label_paths = label_paths[: opt.max_dataset_size]
image_paths = image_paths[: opt.max_dataset_size]
instance_paths = instance_paths[: opt.max_dataset_size]
if not opt.no_pairing_check:
for path1, path2 in zip(label_paths, image_paths):
assert self.paths_match(path1, path2), (
"The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
% (path1, path2)
)
self.label_paths = label_paths
self.image_paths = image_paths
self.instance_paths = instance_paths
size = len(self.label_paths)
self.dataset_size = size
def get_paths(self, opt):
label_paths = []
image_paths = []
instance_paths = []
assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
return label_paths, image_paths, instance_paths
def paths_match(self, path1, path2):
filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
return filename1_without_ext == filename2_without_ext
def __getitem__(self, index):
# Label Image
label_path = self.label_paths[index]
label = Image.open(label_path)
params = get_params(self.opt, label.size)
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
label_tensor = transform_label(label) * 255.0
label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc
# input image (real images)
image_path = self.image_paths[index]
assert self.paths_match(
label_path, image_path
), "The label_path %s and image_path %s don't match." % (label_path, image_path)
image = Image.open(image_path)
image = image.convert("RGB")
transform_image = get_transform(self.opt, params)
image_tensor = transform_image(image)
# if using instance maps
if self.opt.no_instance:
instance_tensor = 0
else:
instance_path = self.instance_paths[index]
instance = Image.open(instance_path)
if instance.mode == "L":
instance_tensor = transform_label(instance) * 255
instance_tensor = instance_tensor.long()
else:
instance_tensor = transform_label(instance)
input_dict = {
"label": label_tensor,
"instance": instance_tensor,
"image": image_tensor,
"path": image_path,
}
# Give subclasses a chance to modify the final output
self.postprocess(input_dict)
return input_dict
def postprocess(self, input_dict):
return input_dict
def __len__(self):
return self.dataset_size


@@ -1,45 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import torch
def find_model_using_name(model_name):
# Given the option --model [modelname],
# the file "models/modelname_model.py"
# will be imported.
assert model_name == 'pix2pix'
model_filename = f"{model_name}_model"
from . import pix2pix_model as modellib
# In the file, the class called ModelNameModel() will
# be instantiated. It has to be a subclass of torch.nn.Module,
# and it is case-insensitive.
model = None
target_model_name = model_name.replace("_", "") + "model"
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module):
model = cls
if model is None:
raise SystemError(
"In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
% (model_filename, target_model_name)
)
return model
def get_option_setter(model_name):
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % (type(instance).__name__))
return instance


@@ -1,57 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from .base_network import BaseNetwork
from .generator import *
from .encoder import *
from ...util import util
def find_network_using_name(target_network_name, filename):
target_class_name = target_network_name + filename
module_name = "models.networks." + filename
network = util.find_class_in_module(target_class_name, module_name)
assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network
return network
def modify_commandline_options(parser, is_train):
opt, _ = parser.parse_known_args()
netG_cls = find_network_using_name(opt.netG, "generator")
parser = netG_cls.modify_commandline_options(parser, is_train)
if is_train:
netD_cls = find_network_using_name(opt.netD, "discriminator")
parser = netD_cls.modify_commandline_options(parser, is_train)
netE_cls = find_network_using_name("conv", "encoder")
parser = netE_cls.modify_commandline_options(parser, is_train)
return parser
def create_network(cls, opt):
net = cls(opt)
net.print_network()
if torch.cuda.is_available() and len(opt.gpu_ids) > 0:
net.cuda()
net.init_weights(opt.init_type, opt.init_variance)
return net
def define_G(opt):
netG_cls = find_network_using_name(opt.netG, "generator")
return create_network(netG_cls, opt)
def define_D(opt):
netD_cls = find_network_using_name(opt.netD, "discriminator")
return create_network(netD_cls, opt)
def define_E(opt):
# there exists only one encoder type
netE_cls = find_network_using_name("conv", "encoder")
return create_network(netE_cls, opt)


@@ -1,173 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from .normalization import SPADE
# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This architecture seemed like a standard architecture for unconditional or
# class-conditional GAN architecture using residual block.
# The code was inspired by https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, opt):
super().__init__()
# Attributes
self.learned_shortcut = fin != fout
fmiddle = min(fin, fout)
self.opt = opt
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if "spectral" in opt.norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = opt.norm_G.replace("spectral", "")
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
# note the resnet block with SPADE also takes in |seg|,
# the semantic segmentation map as input
def forward(self, x, seg, degraded_image):
x_s = self.shortcut(x, seg, degraded_image)
dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))
out = x_s + dx
return out
def shortcut(self, x, seg, degraded_image):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
super().__init__()
pw = (kernel_size - 1) // 2
self.conv_block = nn.Sequential(
nn.ReflectionPad2d(pw),
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
activation,
nn.ReflectionPad2d(pw),
norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
)
def forward(self, x):
y = self.conv_block(x)
out = x + y
return out
# VGG architecture, used for the perceptual loss with a pretrained VGG network
class VGG19(torch.nn.Module):
def __init__(self, requires_grad=False):
super().__init__()
vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class SPADEResnetBlock_non_spade(nn.Module):
def __init__(self, fin, fout, opt):
super().__init__()
# Attributes
self.learned_shortcut = fin != fout
fmiddle = min(fin, fout)
self.opt = opt
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if "spectral" in opt.norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = opt.norm_G.replace("spectral", "")
self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
# note the resnet block with SPADE also takes in |seg|,
# the semantic segmentation map as input
def forward(self, x, seg, degraded_image):
x_s = self.shortcut(x, seg, degraded_image)
dx = self.conv_0(self.actvn(x))
dx = self.conv_1(self.actvn(dx))
out = x_s + dx
return out
def shortcut(self, x, seg, degraded_image):
if self.learned_shortcut:
x_s = self.conv_s(x)
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)


@@ -1,58 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
from torch.nn import init
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
@staticmethod
def modify_commandline_options(parser, is_train):
return parser
def print_network(self):
if isinstance(self, list):
self = self[0]
num_params = 0
for param in self.parameters():
num_params += param.numel()
print(
"Network [%s] was created. Total number of parameters: %.1f million."
% (type(self).__name__, num_params / 1000000)
)
def init_weights(self, init_type="normal", gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if classname.find("BatchNorm2d") != -1:
if hasattr(m, "weight") and m.weight is not None:
init.normal_(m.weight.data, 1.0, gain)
if hasattr(m, "bias") and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
if init_type == "normal":
init.normal_(m.weight.data, 0.0, gain)
elif init_type == "xavier":
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == "xavier_uniform":
init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == "kaiming":
init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
elif init_type == "orthogonal":
init.orthogonal_(m.weight.data, gain=gain)
elif init_type == "none": # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
if hasattr(m, "bias") and m.bias is not None:
init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, "init_weights"):
m.init_weights(init_type, gain)


@@ -1,53 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from .base_network import BaseNetwork
from .normalization import get_nonspade_norm_layer
class ConvEncoder(BaseNetwork):
""" Same architecture as the image discriminator """
def __init__(self, opt):
super().__init__()
kw = 3
pw = int(np.ceil((kw - 1.0) / 2))
ndf = opt.ngf
norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
if opt.crop_size >= 256:
self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
self.so = s0 = 4
self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)
self.actvn = nn.LeakyReLU(0.2, False)
self.opt = opt
def forward(self, x):
if x.size(2) != 256 or x.size(3) != 256:
x = F.interpolate(x, size=(256, 256), mode="bilinear")
x = self.layer1(x)
x = self.layer2(self.actvn(x))
x = self.layer3(self.actvn(x))
x = self.layer4(self.actvn(x))
x = self.layer5(self.actvn(x))
if self.opt.crop_size >= 256:
x = self.layer6(self.actvn(x))
x = self.actvn(x)
x = x.view(x.size(0), -1)
mu = self.fc_mu(x)
logvar = self.fc_var(x)
return mu, logvar
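# Illustrative note (not part of this file): a model consuming the (mu, logvar) pair above
# would typically draw the latent with the usual VAE reparameterization, e.g.
# z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu).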


@@ -1,233 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_network import BaseNetwork
from .normalization import get_nonspade_norm_layer
from .architecture import ResnetBlock as ResnetBlock
from .architecture import SPADEResnetBlock as SPADEResnetBlock
from .architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade
class SPADEGenerator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
parser.add_argument(
"--num_upsampling_layers",
choices=("normal", "more", "most"),
default="normal",
help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
)
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
nf = opt.ngf
self.sw, self.sh = self.compute_latent_vector_size(opt)
print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))
if opt.use_vae:
# In case of VAE, we will sample from random z vector
self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
else:
# Otherwise, we make the network deterministic by starting with
# downsampled segmentation map instead of random z
if self.opt.no_parsing_map:
self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
else:
self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)
if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
else:
self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
else:
self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
else:
self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)
if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
else:
self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)
if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
else:
self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)
if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
else:
self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)
final_nc = nf
if opt.num_upsampling_layers == "most":
self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
final_nc = nf // 2
self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
self.up = nn.Upsample(scale_factor=2)
def compute_latent_vector_size(self, opt):
if opt.num_upsampling_layers == "normal":
num_up_layers = 5
elif opt.num_upsampling_layers == "more":
num_up_layers = 6
elif opt.num_upsampling_layers == "most":
num_up_layers = 7
else:
raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)
sw = opt.load_size // (2 ** num_up_layers)
sh = round(sw / opt.aspect_ratio)
return sw, sh
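# Worked example (values assumed for illustration): with load_size=256,
# num_upsampling_layers="normal" (5 upsampling stages) and aspect_ratio=1,
# sw = 256 // 2**5 = 8 and sh = round(8 / 1) = 8.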
def forward(self, input, degraded_image, z=None):
seg = input
if self.opt.use_vae:
# we sample z from unit normal and reshape the tensor
if z is None:
z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())
x = self.fc(z)
x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
else:
# we downsample segmap and run convolution
if self.opt.no_parsing_map:
x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
else:
x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
x = self.fc(x)
x = self.head_0(x, seg, degraded_image)
x = self.up(x)
x = self.G_middle_0(x, seg, degraded_image)
if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
x = self.up(x)
x = self.G_middle_1(x, seg, degraded_image)
x = self.up(x)
x = self.up_0(x, seg, degraded_image)
x = self.up(x)
x = self.up_1(x, seg, degraded_image)
x = self.up(x)
x = self.up_2(x, seg, degraded_image)
x = self.up(x)
x = self.up_3(x, seg, degraded_image)
if self.opt.num_upsampling_layers == "most":
x = self.up(x)
x = self.up_4(x, seg, degraded_image)
x = self.conv_img(F.leaky_relu(x, 2e-1))
x = torch.tanh(x)  # torch.tanh avoids the deprecated F.tanh
return x
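# Call sketch (variable names are assumed for illustration): the generator expects the
# parsing map and the degraded photo as two aligned N x C x H x W tensors, e.g.
#   netG = SPADEGenerator(opt)
#   restored = netG(parsing_map, degraded_face)  # output in [-1, 1] after tanh
# When opt.no_parsing_map is set, the parsing map argument is still passed but the
# SPADE blocks condition on the degraded face alone.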
class Pix2PixHDGenerator(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument(
"--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
)
parser.add_argument(
"--resnet_n_blocks",
type=int,
default=9,
help="number of residual blocks in the global generator network",
)
parser.add_argument(
"--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
)
parser.add_argument(
"--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
)
# parser.set_defaults(norm_G='instance')
return parser
def __init__(self, opt):
super().__init__()
input_nc = 3
# print("xxxxx")
# print(opt.norm_G)
norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
activation = nn.ReLU(False)
model = []
# initial conv
model += [
nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
activation,
]
# downsample
mult = 1
for i in range(opt.resnet_n_downsample):
model += [
norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
activation,
]
mult *= 2
# resnet blocks
for i in range(opt.resnet_n_blocks):
model += [
ResnetBlock(
opt.ngf * mult,
norm_layer=norm_layer,
activation=activation,
kernel_size=opt.resnet_kernel_size,
)
]
# upsample
for i in range(opt.resnet_n_downsample):
nc_in = int(opt.ngf * mult)
nc_out = int((opt.ngf * mult) / 2)
model += [
norm_layer(
nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
),
activation,
]
mult = mult // 2
# final output conv
model += [
nn.ReflectionPad2d(3),
nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
nn.Tanh(),
]
self.model = nn.Sequential(*model)
def forward(self, input, degraded_image, z=None):
return self.model(degraded_image)


@@ -1,100 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
def get_nonspade_norm_layer(opt, norm_type="instance"):
# helper function to get # output channels of the previous layer
def get_out_channel(layer):
if hasattr(layer, "out_channels"):
return getattr(layer, "out_channels")
return layer.weight.size(0)
# this function will be returned
def add_norm_layer(layer):
nonlocal norm_type
if norm_type.startswith("spectral"):
layer = spectral_norm(layer)
subnorm_type = norm_type[len("spectral") :]
if subnorm_type == "none" or len(subnorm_type) == 0:
return layer
# remove bias in the previous layer, which is meaningless
# since it has no effect after normalization
if getattr(layer, "bias", None) is not None:
delattr(layer, "bias")
layer.register_parameter("bias", None)
if subnorm_type == "batch":
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == "sync_batch":
norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == "instance":
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
else:
raise ValueError("normalization layer %s is not recognized" % subnorm_type)
return nn.Sequential(layer, norm_layer)
return add_norm_layer
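# Usage sketch (assumed): the returned closure wraps a bare layer with optional
# spectral norm plus the requested parameter-free norm, e.g.
#   norm_layer = get_nonspade_norm_layer(opt, "spectralinstance")
#   block = norm_layer(nn.Conv2d(64, 128, kernel_size=3, padding=1))
# which yields nn.Sequential(spectral_norm(conv), nn.InstanceNorm2d(128)).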
class SPADE(nn.Module):
def __init__(self, config_text, norm_nc, label_nc, opt):
super().__init__()
assert config_text.startswith("spade")
parsed = re.search("spade(\D+)(\d)x\d", config_text)
param_free_norm_type = str(parsed.group(1))
ks = int(parsed.group(2))
self.opt = opt
if param_free_norm_type == "instance":
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
elif param_free_norm_type == "syncbatch":
self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
elif param_free_norm_type == "batch":
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
else:
raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)
# The dimension of the intermediate embedding space. Yes, hardcoded.
nhidden = 128
pw = ks // 2
if self.opt.no_parsing_map:
self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
else:
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
)
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
def forward(self, x, segmap, degraded_image):
# Part 1. generate parameter-free normalized activations
normalized = self.param_free_norm(x)
# Part 2. produce scaling and bias conditioned on semantic map
segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")
if self.opt.no_parsing_map:
actv = self.mlp_shared(degraded_face)
else:
actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
# apply scale and bias
out = normalized * (1 + gamma) + beta
return out
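# Shape sketch (channel counts are assumed for illustration): for activations x of
# shape (N, 1024, 8, 8) and a 19-channel parsing map, mlp_shared consumes the
# concatenated (N, 19 + 3, 8, 8) tensor (parsing map + resized degraded face), and
# gamma and beta both come out as (N, 1024, 8, 8), matching the normalized activations.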


@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import set_sbn_eps_mode
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback


@@ -1,412 +0,0 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import contextlib
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
try:
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
ReduceAddCoalesced = Broadcast = None
try:
from jactorch.parallel.comm import SyncMaster
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
from .comm import SyncMaster
from .replicate import DataParallelWithCallback
__all__ = [
'set_sbn_eps_mode',
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
'patch_sync_batchnorm', 'convert_model'
]
SBN_EPS_MODE = 'clamp'
def set_sbn_eps_mode(mode):
global SBN_EPS_MODE
assert mode in ('clamp', 'plus')
SBN_EPS_MODE = mode
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dimensions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
track_running_stats=track_running_stats)
if not self.track_running_stats:
import warnings
warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
if hasattr(torch, 'no_grad'):
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
else:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
if SBN_EPS_MODE == 'clamp':
return mean, bias_var.clamp(self.eps) ** -0.5
elif SBN_EPS_MODE == 'plus':
return mean, (bias_var + self.eps) ** -0.5
else:
raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
@contextlib.contextmanager
def patch_sync_batchnorm():
import torch.nn as nn
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
nn.BatchNorm1d = SynchronizedBatchNorm1d
nn.BatchNorm2d = SynchronizedBatchNorm2d
nn.BatchNorm3d = SynchronizedBatchNorm3d
yield
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
def convert_model(module):
"""Traverse the input module and its child recursively
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
to SynchronizedBatchNorm*N*d
Args:
module: the input module needs to be convert to SyncBN model
Examples:
>>> import torch.nn as nn
>>> import torchvision
>>> # m is a standard pytorch model
>>> m = torchvision.models.resnet18(True)
>>> m = nn.DataParallel(m)
>>> # after convert, m is using SyncBN
>>> m = convert_model(m)
"""
if isinstance(module, torch.nn.DataParallel):
mod = module.module
mod = convert_model(mod)
mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
return mod
mod = module
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.BatchNorm3d],
[SynchronizedBatchNorm1d,
SynchronizedBatchNorm2d,
SynchronizedBatchNorm3d]):
if isinstance(module, pth_module):
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_model(child))
return mod


@@ -1,74 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batchnorm_reimpl.py
# Author : acgtyrant
# Date : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import torch
import torch.nn as nn
import torch.nn.init as init
__all__ = ['BatchNorm2dReimpl']
class BatchNorm2dReimpl(nn.Module):
"""
A re-implementation of batch normalization, used for testing the numerical
stability.
Author: acgtyrant
See also:
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.weight = nn.Parameter(torch.empty(num_features))
self.bias = nn.Parameter(torch.empty(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
def reset_parameters(self):
self.reset_running_stats()
init.uniform_(self.weight)
init.zeros_(self.bias)
def forward(self, input_):
batchsize, channels, height, width = input_.size()
numel = batchsize * height * width
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
sum_ = input_.sum(1)
sum_of_square = input_.pow(2).sum(1)
mean = sum_ / numel
sumvar = sum_of_square - sum_ * mean
self.running_mean = (
(1 - self.momentum) * self.running_mean
+ self.momentum * mean.detach()
)
unbias_var = sumvar / (numel - 1)
self.running_var = (
(1 - self.momentum) * self.running_var
+ self.momentum * unbias_var.detach()
)
bias_var = sumvar / numel
inv_std = 1 / (bias_var + self.eps).pow(0.5)
output = (
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()


@@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During the replication, as data parallel will trigger a callback on each module, all slave devices should
call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be collected
and passed to a registered callback.
- After receiving the messages, the master device should gather the information and determine the message to be passed
back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
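# Protocol sketch (illustrative only; the real wiring lives in _SynchronizedBatchNorm
# and DataParallelWithCallback): each non-master replica registers once,
#   pipe = sync_master.register_slave(copy_id)
# then on every forward pass the master calls
#   mean, inv_std = sync_master.run_master(_ChildMessage(sum, ssum, size))
# while each replica calls pipe.run_slave(_ChildMessage(...)); run_master blocks until
# all registered replicas have contributed their partial statistics.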


@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all modules are isomorphic, we assign each sub-module a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
the original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have a customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate


@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import torch
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, x, y):
adiff = float((x - y).abs().max())
if (y == 0).all():
rdiff = 'NaN'
else:
rdiff = float((adiff / y).abs().max())
message = (
'Tensor close check failed\n'
'adiff={}\n'
'rdiff={}\n'
).format(adiff, rdiff)
self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)


@@ -1,246 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from . import networks
from ..util import util
class Pix2PixModel(torch.nn.Module):
@staticmethod
def modify_commandline_options(parser, is_train):
networks.modify_commandline_options(parser, is_train)
return parser
def __init__(self, opt):
super().__init__()
self.opt = opt
self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor
self.netG, self.netD, self.netE = self.initialize_networks(opt)
# set loss functions
if opt.isTrain:
self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
self.criterionFeat = torch.nn.L1Loss()
if not opt.no_vgg_loss:
self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
if opt.use_vae:
self.KLDLoss = networks.KLDLoss()
# Entry point for all calls involving a forward pass of the deep networks.
# We use this approach since the DataParallel module can't parallelize
# custom functions; we branch to different routines based on |mode|.
def forward(self, data, mode):
input_semantics, real_image, degraded_image = self.preprocess_input(data)
if mode == "generator":
g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
return g_loss, generated
elif mode == "discriminator":
d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
return d_loss
elif mode == "encode_only":
z, mu, logvar = self.encode_z(real_image)
return mu, logvar
elif mode == "inference":
with torch.no_grad():
fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
return fake_image
else:
raise ValueError("|mode| is invalid")
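# Call-site sketch (assumed; the training loop itself is not part of this file):
#   g_losses, generated = model(data_i, mode="generator")
#   d_losses = model(data_i, mode="discriminator")
# while the inference entry point (see test_face in this commit) only uses
#   fake_image = model(data_i, mode="inference")
# with data_i being the dict produced by the dataloader (label / image / degraded_image).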
def create_optimizers(self, opt):
G_params = list(self.netG.parameters())
if opt.use_vae:
G_params += list(self.netE.parameters())
if opt.isTrain:
D_params = list(self.netD.parameters())
beta1, beta2 = opt.beta1, opt.beta2
if opt.no_TTUR:
G_lr, D_lr = opt.lr, opt.lr
else:
G_lr, D_lr = opt.lr / 2, opt.lr * 2
optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
return optimizer_G, optimizer_D
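# TTUR example (learning rate is illustrative): with opt.lr = 0.0002 and --no_TTUR not
# set, the generator trains at G_lr = 0.0001 and the discriminator at D_lr = 0.0004
# (the two time-scale update rule of Heusel et al.); with --no_TTUR both optimizers
# simply use opt.lr.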
def save(self, epoch):
util.save_network(self.netG, "G", epoch, self.opt)
util.save_network(self.netD, "D", epoch, self.opt)
if self.opt.use_vae:
util.save_network(self.netE, "E", epoch, self.opt)
############################################################################
# Private helper methods
############################################################################
def initialize_networks(self, opt):
netG = networks.define_G(opt)
netD = networks.define_D(opt) if opt.isTrain else None
netE = networks.define_E(opt) if opt.use_vae else None
if not opt.isTrain or opt.continue_train:
netG = util.load_network(netG, "G", opt.which_epoch, opt)
if opt.isTrain:
netD = util.load_network(netD, "D", opt.which_epoch, opt)
if opt.use_vae:
netE = util.load_network(netE, "E", opt.which_epoch, opt)
return netG, netD, netE
# preprocess the input, such as moving the tensors to GPUs and
# transforming the label map to one-hot encoding
# |data|: dictionary of the input data
def preprocess_input(self, data):
# move to GPU and change data types
# data['label'] = data['label'].long()
if not self.opt.isTrain:
if self.use_gpu():
data["label"] = data["label"].cuda()
data["image"] = data["image"].cuda()
return data["label"], data["image"], data["image"]
## While testing, the input image is the degraded face
if self.use_gpu():
data["label"] = data["label"].cuda()
data["degraded_image"] = data["degraded_image"].cuda()
data["image"] = data["image"].cuda()
# # create one-hot label map
# label_map = data['label']
# bs, _, h, w = label_map.size()
# nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
# else self.opt.label_nc
# input_label = self.FloatTensor(bs, nc, h, w).zero_()
# input_semantics = input_label.scatter_(1, label_map, 1.0)
return data["label"], data["image"], data["degraded_image"]
def compute_generator_loss(self, input_semantics, degraded_image, real_image):
G_losses = {}
fake_image, KLD_loss = self.generate_fake(
input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
)
if self.opt.use_vae:
G_losses["KLD"] = KLD_loss
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)
if not self.opt.no_ganFeat_loss:
num_D = len(pred_fake)
GAN_Feat_loss = self.FloatTensor(1).fill_(0)
for i in range(num_D): # for each discriminator
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs): # for each layer output
unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
G_losses["GAN_Feat"] = GAN_Feat_loss
if not self.opt.no_vgg_loss:
G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg
return G_losses, fake_image
def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
D_losses = {}
with torch.no_grad():
fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
fake_image = fake_image.detach()
fake_image.requires_grad_()
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)
return D_losses
def encode_z(self, real_image):
mu, logvar = self.netE(real_image)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
z = None
KLD_loss = None
if self.opt.use_vae:
z, mu, logvar = self.encode_z(real_image)
if compute_kld_loss:
KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
fake_image = self.netG(input_semantics, degraded_image, z=z)
assert (
not compute_kld_loss
) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
return fake_image, KLD_loss
# Given fake and real image, return the prediction of discriminator
# for each fake and real image.
def discriminate(self, input_semantics, fake_image, real_image):
if self.opt.no_parsing_map:
fake_concat = fake_image
real_concat = real_image
else:
fake_concat = torch.cat([input_semantics, fake_image], dim=1)
real_concat = torch.cat([input_semantics, real_image], dim=1)
# In Batch Normalization, the fake and real images are
# recommended to be in the same batch to avoid disparate
# statistics in fake and real images.
# So both fake and real images are fed to D all at once.
fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
discriminator_out = self.netD(fake_and_real)
pred_fake, pred_real = self.divide_pred(discriminator_out)
return pred_fake, pred_real
# Take the prediction of fake and real images from the combined batch
def divide_pred(self, pred):
# the prediction contains the intermediate outputs of multiscale GAN,
# so it's usually a list
if type(pred) == list:
fake = []
real = []
for p in pred:
fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
else:
fake = pred[: pred.size(0) // 2]
real = pred[pred.size(0) // 2 :]
return fake, real
def get_edges(self, t):
edge = self.ByteTensor(t.size()).zero_()
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
return edge.float()
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps.mul(std) + mu
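# Reparameterization trick: z = mu + eps * std with std = exp(0.5 * logvar) and
# eps ~ N(0, I), so gradients flow to mu and logvar while sampling stays stochastic.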
def use_gpu(self):
return torch.cuda.is_available() and len(self.opt.gpu_ids) > 0

View File

@@ -1,2 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


@@ -1,347 +0,0 @@
# Copyright (c) Microsoft Corporation
import argparse
import pickle
import torch
import sys
import os
from ..util import util
from .. import models
class BaseOptions:
def __init__(self):
self.initialized = False
def initialize(self, parser):
# experiment specifics
parser.add_argument(
"--name",
type=str,
default="label2coco",
help="name of the experiment. It decides where to store samples and models",
)
parser.add_argument(
"--gpu_ids",
type=str,
default="0",
help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU",
)
parser.add_argument(
"--checkpoints_dir",
type=str,
default="./checkpoints",
help="models are saved here",
)
parser.add_argument(
"--model", type=str, default="pix2pix", help="which model to use"
)
parser.add_argument(
"--norm_G",
type=str,
default="spectralinstance",
help="instance normalization or batch normalization",
)
parser.add_argument(
"--norm_D",
type=str,
default="spectralinstance",
help="instance normalization or batch normalization",
)
parser.add_argument(
"--norm_E",
type=str,
default="spectralinstance",
help="instance normalization or batch normalization",
)
parser.add_argument(
"--phase", type=str, default="train", help="train, val, test, etc"
)
# input/output sizes
parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
parser.add_argument(
"--preprocess_mode",
type=str,
default="scale_width_and_crop",
help="scaling and cropping of images at load time.",
choices=(
"resize_and_crop",
"crop",
"scale_width",
"scale_width_and_crop",
"scale_shortside",
"scale_shortside_and_crop",
"fixed",
"none",
"resize",
),
)
parser.add_argument(
"--load_size",
type=int,
default=1024,
help="Scale images to this size. The final image will be cropped to --crop_size.",
)
parser.add_argument(
"--crop_size",
type=int,
default=512,
help="Crop to the width of crop_size (after initially scaling the images to load_size.)",
)
parser.add_argument(
"--aspect_ratio",
type=float,
default=1.0,
help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio",
)
parser.add_argument(
"--label_nc",
type=int,
default=182,
help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.",
)
parser.add_argument(
"--contain_dontcare_label",
action="store_true",
help="if the label map contains dontcare label (dontcare=255)",
)
parser.add_argument(
"--output_nc", type=int, default=3, help="# of output image channels"
)
# for setting inputs
parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
parser.add_argument("--dataset_mode", type=str, default="coco")
parser.add_argument(
"--serial_batches",
action="store_true",
help="if true, takes images in order to make batches, otherwise takes them randomly",
)
parser.add_argument(
"--no_flip",
action="store_true",
help="if specified, do not flip the images for data argumentation",
)
parser.add_argument(
"--nThreads", default=0, type=int, help="# threads for loading data"
)
parser.add_argument(
"--max_dataset_size",
type=int,
default=sys.maxsize,
help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
)
parser.add_argument(
"--load_from_opt_file",
action="store_true",
help="load the options from checkpoints and use that as default",
)
parser.add_argument(
"--cache_filelist_write",
action="store_true",
help="saves the current filelist into a text file, so that it loads faster",
)
parser.add_argument(
"--cache_filelist_read",
action="store_true",
help="reads from the file list cache",
)
# for displays
parser.add_argument(
"--display_winsize", type=int, default=400, help="display window size"
)
# for generator
parser.add_argument(
"--netG",
type=str,
default="spade",
help="selects model to use for netG (pix2pixhd | spade)",
)
parser.add_argument(
"--ngf", type=int, default=64, help="# of gen filters in first conv layer"
)
parser.add_argument(
"--init_type",
type=str,
default="xavier",
help="network initialization [normal|xavier|kaiming|orthogonal]",
)
parser.add_argument(
"--init_variance",
type=float,
default=0.02,
help="variance of the initialization distribution",
)
parser.add_argument(
"--z_dim", type=int, default=256, help="dimension of the latent z vector"
)
parser.add_argument(
"--no_parsing_map",
action="store_true",
help="During training, we do not use the parsing map",
)
# for instance-wise features
parser.add_argument(
"--no_instance",
action="store_true",
help="if specified, do *not* add instance map as input",
)
parser.add_argument(
"--nef",
type=int,
default=16,
help="# of encoder filters in the first conv layer",
)
parser.add_argument(
"--use_vae",
action="store_true",
help="enable training with an image encoder.",
)
parser.add_argument(
"--tensorboard_log",
action="store_true",
help="use tensorboard to record the resutls",
)
# parser.add_argument('--img_dir',)
parser.add_argument(
"--old_face_folder",
type=str,
default="",
help="The folder name of input old face",
)
parser.add_argument(
"--old_face_label_folder",
type=str,
default="",
help="The folder name of input old face label",
)
parser.add_argument("--injection_layer", type=str, default="all", help="")
self.initialized = True
return parser
def gather_options(self):
# initialize parser with basic options
if not self.initialized:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser = self.initialize(parser)
# get the basic options
opt, unknown = parser.parse_known_args()
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
# modify dataset-related parser options
# dataset_mode = opt.dataset_mode
# dataset_option_setter = data.get_option_setter(dataset_mode)
# parser = dataset_option_setter(parser, self.isTrain)
opt, unknown = parser.parse_known_args()
# if there is opt_file, load it.
# The previous default options will be overwritten
if opt.load_from_opt_file:
parser = self.update_options_from_file(parser, opt)
opt, unknown = parser.parse_known_args()
self.parser = parser
return opt
def print_options(self, opt):
message = ""
message += "----------------- Options ---------------\n"
for k, v in sorted(vars(opt).items()):
comment = ""
default = self.parser.get_default(k)
if v != default:
comment = "\t[default: %s]" % str(default)
message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
message += "----------------- End -------------------"
# print(message)
def option_file_path(self, opt, makedir=False):
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
if makedir:
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, "opt")
return file_name
def save_options(self, opt):
file_name = self.option_file_path(opt, makedir=True)
with open(file_name + ".txt", "wt") as opt_file:
for k, v in sorted(vars(opt).items()):
comment = ""
default = self.parser.get_default(k)
if v != default:
comment = "\t[default: %s]" % str(default)
opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment))
with open(file_name + ".pkl", "wb") as opt_file:
pickle.dump(opt, opt_file)
def update_options_from_file(self, parser, opt):
new_opt = self.load_options(opt)
for k, v in sorted(vars(opt).items()):
if hasattr(new_opt, k) and v != getattr(new_opt, k):
new_val = getattr(new_opt, k)
parser.set_defaults(**{k: new_val})
return parser
def load_options(self, opt):
file_name = self.option_file_path(opt, makedir=False)
new_opt = pickle.load(open(file_name + ".pkl", "rb"))
return new_opt
def parse(self):
opt = self.gather_options()
opt.isTrain = self.isTrain # train or test
opt.contain_dontcare_label = False
self.print_options(opt)
if opt.isTrain:
self.save_options(opt)
# Set semantic_nc based on the option.
# This will be convenient in many places
opt.semantic_nc = (
opt.label_nc
+ (1 if opt.contain_dontcare_label else 0)
+ (0 if opt.no_instance else 1)
)
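# Worked example (label count is assumed): with --label_nc 18, contain_dontcare_label
# forced to False above and instance maps enabled (--no_instance not set),
# semantic_nc = 18 + 0 + 1 = 19 channels are expected for the parsing input.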
str_ids = opt.gpu_ids.split(",")
opt.gpu_ids = []
for str_id in str_ids:
int_id = int(str_id)
if int_id >= 0:
opt.gpu_ids.append(int_id)
# set gpu ids
if torch.cuda.is_available() and len(opt.gpu_ids) > 0:
if torch.cuda.device_count() > opt.gpu_ids[0]:
try:
torch.cuda.set_device(opt.gpu_ids[0])
except:
print("Failed to set GPU device. Using CPU...")
else:
print("Invalid GPU ID. Using CPU...")
# assert (len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids))
self.opt = opt
return self.opt


@@ -1,26 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self, parser):
BaseOptions.initialize(self, parser)
parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
parser.add_argument(
"--which_epoch",
type=str,
default="latest",
help="which epoch to load? set to latest to use latest cached model",
)
parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")
parser.set_defaults(
preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256
)
parser.set_defaults(serial_batches=True)
parser.set_defaults(no_flip=True)
parser.set_defaults(phase="test")
self.isTrain = False
return parser


@@ -1,34 +0,0 @@
# Copyright (c) Microsoft Corporation
import torchvision.transforms as T
import warnings
from .options.test_options import TestOptions
from .models.pix2pix_model import Pix2PixModel
from .data import create_dataloader
warnings.filterwarnings("ignore", category=UserWarning)
tensor2image = T.ToPILImage()
def test_face(face_images: list, args: dict) -> list:
opt = TestOptions().parse()
for K, V in args.items():
setattr(opt, K, V)
dataloader = create_dataloader(opt, face_images)
images = []
model = Pix2PixModel(opt)
model.eval()
for i, data_i in enumerate(dataloader):
if i * opt.batchSize >= opt.how_many:
break
generated = model(data_i, mode="inference")
for b in range(generated.shape[0]):
images.append(tensor2image((generated[b] + 1) / 2))
return images
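# Usage sketch (paths and option values are hypothetical): a caller is expected to do
#   faces = [Image.open("aligned_face_0.png")]
#   restored = test_face(faces, {"checkpoints_dir": "./checkpoints", "name": "face_checkpoint"})
# and receive a list of PIL images, one restored crop per input face, in order.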


@@ -1,2 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


@@ -1,74 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import time
import numpy as np
# Helper class that keeps track of training iterations
class IterationCounter:
def __init__(self, opt, dataset_size):
self.opt = opt
self.dataset_size = dataset_size
self.first_epoch = 1
self.total_epochs = opt.niter + opt.niter_decay
self.epoch_iter = 0 # iter number within each epoch
self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt")
if opt.isTrain and opt.continue_train:
try:
self.first_epoch, self.epoch_iter = np.loadtxt(
self.iter_record_path, delimiter=",", dtype=int
)
print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter))
except:
print(
"Could not load iteration record at %s. Starting from beginning." % self.iter_record_path
)
self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
# return the iterator of epochs for the training
def training_epochs(self):
return range(self.first_epoch, self.total_epochs + 1)
def record_epoch_start(self, epoch):
self.epoch_start_time = time.time()
self.epoch_iter = 0
self.last_iter_time = time.time()
self.current_epoch = epoch
def record_one_iteration(self):
current_time = time.time()
# the last remaining batch is dropped (see data/__init__.py),
# so we can assume batch size is always opt.batchSize
self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
self.last_iter_time = current_time
self.total_steps_so_far += self.opt.batchSize
self.epoch_iter += self.opt.batchSize
def record_epoch_end(self):
current_time = time.time()
self.time_per_epoch = current_time - self.epoch_start_time
print(
"End of epoch %d / %d \t Time Taken: %d sec"
% (self.current_epoch, self.total_epochs, self.time_per_epoch)
)
if self.current_epoch % self.opt.save_epoch_freq == 0:
np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d")
print("Saved current iteration count at %s." % self.iter_record_path)
def record_current_iter(self):
np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d")
print("Saved current iteration count at %s." % self.iter_record_path)
def needs_saving(self):
return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
def needs_printing(self):
return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
def needs_displaying(self):
return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize


@@ -1,217 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
import argparse
import dill as pickle
def save_obj(obj, name):
with open(name, "wb") as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
def load_obj(name):
with open(name, "rb") as f:
return pickle.load(f)
def copyconf(default_opt, **kwargs):
conf = argparse.Namespace(**vars(default_opt))
for key in kwargs:
print(key, kwargs[key])
setattr(conf, key, kwargs[key])
return conf
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
if isinstance(image_tensor, list):
image_numpy = []
for i in range(len(image_tensor)):
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
return image_numpy
if image_tensor.dim() == 4:
# transform each image in the batch
images_np = []
for b in range(image_tensor.size(0)):
one_image = image_tensor[b]
one_image_np = tensor2im(one_image)
images_np.append(one_image_np.reshape(1, *one_image_np.shape))
images_np = np.concatenate(images_np, axis=0)
return images_np
if image_tensor.dim() == 2:
image_tensor = image_tensor.unsqueeze(0)
image_numpy = image_tensor.detach().cpu().float().numpy()
if normalize:
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
else:
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
image_numpy = np.clip(image_numpy, 0, 255)
if image_numpy.shape[2] == 1:
image_numpy = image_numpy[:, :, 0]
return image_numpy.astype(imtype)
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
if label_tensor.dim() == 4:
# transform each image in the batch
images_np = []
for b in range(label_tensor.size(0)):
one_image = label_tensor[b]
one_image_np = tensor2label(one_image, n_label, imtype)
images_np.append(one_image_np.reshape(1, *one_image_np.shape))
images_np = np.concatenate(images_np, axis=0)
# if tile:
# images_tiled = tile_images(images_np)
# return images_tiled
# else:
# images_np = images_np[0]
# return images_np
return images_np
if label_tensor.dim() == 1:
return np.zeros((64, 64, 3), dtype=np.uint8)
if n_label == 0:
return tensor2im(label_tensor, imtype)
label_tensor = label_tensor.cpu().float()
if label_tensor.size()[0] > 1:
label_tensor = label_tensor.max(0, keepdim=True)[1]
label_tensor = Colorize(n_label)(label_tensor)
label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
result = label_numpy.astype(imtype)
return result
def save_image(image_numpy, image_path, create_dir=False):
if create_dir:
os.makedirs(os.path.dirname(image_path), exist_ok=True)
if len(image_numpy.shape) == 2:
image_numpy = np.expand_dims(image_numpy, axis=2)
if image_numpy.shape[2] == 1:
image_numpy = np.repeat(image_numpy, 3, 2)
image_pil = Image.fromarray(image_numpy)
# save to png
image_pil.save(image_path.replace(".jpg", ".png"))
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
"""
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
"""
return [atoi(c) for c in re.split(r"(\d+)", text)]
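# Example: sorted(["img10.png", "img2.png"], key=natural_keys) -> ["img2.png", "img10.png"],
# because numeric chunks are compared as integers rather than as strings.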
def natural_sort(items):
items.sort(key=natural_keys)
def str2bool(v):
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
def find_class_in_module(target_cls_name, module):
target_cls_name = target_cls_name.replace("_", "").lower()
if module == 'models.networks.generator':
from ..models.networks import generator as clslib
elif module == 'models.networks.encoder':
from ..models.networks import encoder as clslib
else:
raise NotImplementedError(f'Module: {module}')
cls = None
for name, clsobj in clslib.__dict__.items():
if name.lower() == target_cls_name:
cls = clsobj
if cls is None:
print(
"In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
% (module, target_cls_name)
)
exit(0)
return cls
def save_network(net, label, epoch, opt):
save_filename = "%s_net_%s.pth" % (epoch, label)
save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
torch.save(net.cpu().state_dict(), save_path)
if len(opt.gpu_ids) and torch.cuda.is_available():
net.cuda()
def load_network(net, label, epoch, opt):
save_filename = "%s_net_%s.pth" % (epoch, label)
save_dir = os.path.join(opt.checkpoints_dir, opt.name)
save_path = os.path.join(save_dir, save_filename)
if os.path.exists(save_path):
weights = torch.load(save_path)
net.load_state_dict(weights)
return net
###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
"""returns the binary of integer n, count refers to amount of bits"""
return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
class Colorize(object):
def __init__(self, n=35):
self.cmap = labelcolormap(n)
self.cmap = torch.from_numpy(self.cmap[:n])
def __call__(self, gray_image):
size = gray_image.size()
color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
for label in range(0, len(self.cmap)):
mask = (label == gray_image[0]).cpu()
color_image[0][mask] = self.cmap[label][0]
color_image[1][mask] = self.cmap[label][1]
color_image[2][mask] = self.cmap[label][2]
return color_image


@@ -1,134 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import ntpath
import time
from . import util
import scipy.misc
try:
from io import BytesIO # Python 3.x
except ImportError:
from StringIO import StringIO # Python 2.7
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch
import numpy as np
class Visualizer:
def __init__(self, opt):
self.opt = opt
self.tf_log = opt.isTrain and opt.tf_log
self.tensorboard_log = opt.tensorboard_log
self.win_size = opt.display_winsize
self.name = opt.name
if self.tensorboard_log:
if self.opt.isTrain:
self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.writer = SummaryWriter(log_dir=self.log_dir)
else:
# print("hi :)")
self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
if opt.isTrain:
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write("================ Training Loss (%s) ================\n" % now)
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch, step):
all_tensor = []
if self.tensorboard_log:
for key, tensor in visuals.items():
all_tensor.append((tensor.data.cpu() + 1) / 2)
output = torch.cat(all_tensor, 0)
img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
if self.opt.isTrain:
self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
else:
vutils.save_image(
output,
os.path.join(self.log_dir, str(step) + ".png"),
nrow=self.opt.batchSize,
padding=0,
normalize=False,
)
# errors: dictionary of error labels and values
def plot_current_errors(self, errors, step):
# legacy TensorFlow summary path; it requires a TensorFlow module bound to self.tf,
# which this trimmed Visualizer never sets
if self.tf_log:
for tag, value in errors.items():
value = value.mean().float()
summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
if self.tensorboard_log:
self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step)
self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step)
self.writer.add_scalars(
"Loss/GAN",
{
"G": errors["GAN"].mean().float(),
"D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2,
},
step,
)
# errors: same format as |errors| of plotCurrentErrors
def print_current_errors(self, epoch, i, errors, t):
message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
for k, v in errors.items():
v = v.mean().float()
message += "%s: %.3f " % (k, v)
print(message)
with open(self.log_name, "a") as log_file:
log_file.write("%s\n" % message)
def convert_visuals_to_numpy(self, visuals):
for key, t in visuals.items():
tile = self.opt.batchSize > 8
if "input_label" == key:
t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy
else:
t = util.tensor2im(t, tile=tile)
visuals[key] = t
return visuals
# save image to the disk
def save_images(self, webpage, visuals, image_path):
visuals = self.convert_visuals_to_numpy(visuals)
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
image_name = os.path.join(label, "%s.png" % (name))
save_path = os.path.join(image_dir, image_name)
util.save_image(image_numpy, save_path, create_dir=True)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=self.win_size)

View File

@@ -1,63 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import struct
from PIL import Image
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
#print(fname)
path = os.path.join(root, fname)
images.append(path)
return images
### Modify these 3 lines in your own environment
indir="/home/ziyuwan/workspace/data/temp_old"
target_folders=['VOC','Real_L_old','Real_RGB_old']
out_dir ="/home/ziyuwan/workspace/data/temp_old"
###
if not os.path.exists(out_dir):
os.makedirs(out_dir)
#
for target_folder in target_folders:
curr_indir = os.path.join(indir, target_folder)
curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder)))
image_lists = make_dataset(curr_indir)
image_lists.sort()
with open(curr_out_file, 'wb') as wfid:
# write total image number
wfid.write(struct.pack('i', len(image_lists)))
for i, img_path in enumerate(image_lists):
# write file name first
img_name = os.path.basename(img_path)
img_name_bytes = img_name.encode('utf-8')
wfid.write(struct.pack('i', len(img_name_bytes)))
wfid.write(img_name_bytes)
#
# # write image data in
with open(img_path, 'rb') as img_fid:
img_bytes = img_fid.read()
wfid.write(struct.pack('i', len(img_bytes)))
wfid.write(img_bytes)
if i % 1000 == 0:
print('write %d images done' % i)

View File

@@ -1,42 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import os
import struct
from PIL import Image
class BigFileMemoryLoader(object):
def __load_bigfile(self):
print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
with open(self.file_path, 'rb') as fid:
self.img_num = struct.unpack('i', fid.read(4))[0]
self.img_names = []
self.img_bytes = []
print('find total %d images' % self.img_num)
for i in range(self.img_num):
img_name_len = struct.unpack('i', fid.read(4))[0]
img_name = fid.read(img_name_len).decode('utf-8')
self.img_names.append(img_name)
img_bytes_len = struct.unpack('i', fid.read(4))[0]
self.img_bytes.append(fid.read(img_bytes_len))
if i % 5000 == 0:
print('load %d images done' % i)
print('load all %d images done' % self.img_num)
def __init__(self, file_path):
super(BigFileMemoryLoader, self).__init__()
self.file_path = file_path
self.__load_bigfile()
def __getitem__(self, index):
try:
img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
return self.img_names[index], img
except Exception:
print('Image read error for index %d: %s' % (index, self.img_names[index]))
return self.__getitem__((index+1)%self.img_num)
def __len__(self):
return self.img_num
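# Illustrative usage sketch (hypothetical path; not part of the original file).
# The loader reads the whole .bigfile into memory and yields (name, PIL image) pairs:
# loader = BigFileMemoryLoader("/path/to/VOC.bigfile")
# name, img = loader[0]
# print(len(loader), name, img.size)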

View File

@@ -1,16 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class BaseDataLoader():
def __init__(self):
pass
def initialize(self, opt):
self.opt = opt
pass
def load_data(self):
return None

View File

@@ -1,114 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
new_h = new_w = opt.loadSize
if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256
if w<h:
new_w = opt.loadSize
new_h = opt.loadSize * h // w
else:
new_h=opt.loadSize
new_w = opt.loadSize * w // h
if opt.resize_or_crop=='crop_only':
pass
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
# transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it.
transform_list.append(transforms.Scale(256,method))
if 'crop' in opt.resize_or_crop:
if opt.isTrain:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
else:
if opt.test_random_crop:
transform_list.append(transforms.RandomCrop(opt.fineSize))
else:
transform_list.append(transforms.CenterCrop(opt.fineSize))
## when testing, for ablation study, choose center_crop directly.
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
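# Illustrative usage sketch (hypothetical `opt` and images; not part of the original file).
# get_params samples the crop position / flip decision once, and get_transform turns it
# into a torchvision pipeline, so paired images receive identical augmentation:
# params = get_params(opt, image_A.size)
# transform = get_transform(opt, params)
# tensor_A, tensor_B = transform(image_A), transform(image_B)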

View File

@@ -1,41 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile
def CreateDataset(opt):
dataset = None
if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B':
dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
if opt.training_dataset=='mapping':
if opt.random_hole:
dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
else:
dataset = dts_ray_bigfile.PairOldPhotos()
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads),
drop_last=True)
def load_data(self):
return self.dataloader
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)

View File

@@ -1,9 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader

View File

@@ -1,62 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import os
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)

View File

@@ -1,144 +0,0 @@
# Copyright (c) Microsoft Corporation
import torchvision.transforms as transforms
from PIL import Image, ImageFile
import torch.nn.functional as F
import torchvision as tv
import numpy as np
import warnings
import argparse
import torch
import gc
import os
from .detection_models import networks
tensor2image = transforms.ToPILImage()
warnings.filterwarnings("ignore", category=UserWarning)
ImageFile.LOAD_TRUNCATED_IMAGES = True
def data_transforms(img, full_size, method=Image.BICUBIC):
if full_size == "full_size":
ow, oh = img.size
h = int(round(oh / 16) * 16)
w = int(round(ow / 16) * 16)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
elif full_size == "scale_256":
ow, oh = img.size
pw, ph = ow, oh
if ow < oh:
ow = 256
oh = ph / pw * 256
else:
oh = 256
ow = pw / ph * 256
h = int(round(oh / 16) * 16)
w = int(round(ow / 16) * 16)
if (h == ph) and (w == pw):
return img
return img.resize((w, h), method)
def scale_tensor(img_tensor, default_scale=256):
_, _, w, h = img_tensor.shape
if w < h:
ow = default_scale
oh = h / w * default_scale
else:
oh = default_scale
ow = w / h * default_scale
oh = int(round(oh / 16) * 16)
ow = int(round(ow / 16) * 16)
return F.interpolate(img_tensor, [ow, oh], mode="bilinear")
def blend_mask(img, mask):
np_img = np.array(img).astype("float")
return Image.fromarray(
(np_img * (1 - mask) + mask * 255.0).astype("uint8")
).convert("RGB")
def main(config: argparse.Namespace, input_image: Image.Image):
model = networks.UNet(
in_channels=1,
out_channels=1,
depth=4,
conv_num=2,
wf=6,
padding=True,
batch_norm=True,
up_mode="upsample",
with_tanh=False,
sync_bn=True,
antialiasing=True,
)
## load model
checkpoint_path = os.path.join(
os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
print("model weights loaded")
if torch.cuda.is_available() and config.GPU >= 0:
model.to(config.GPU)
else:
model.cpu()
model.eval()
print("processing...")
transformed_image_PIL = data_transforms(input_image, config.input_size)
input_image = transformed_image_PIL.convert("L")
input_image = tv.transforms.ToTensor()(input_image)
input_image = tv.transforms.Normalize([0.5], [0.5])(input_image)
input_image = torch.unsqueeze(input_image, 0)
_, _, ow, oh = input_image.shape
scratch_image_scale = scale_tensor(input_image)
if torch.cuda.is_available() and config.GPU >= 0:
scratch_image_scale = scratch_image_scale.to(config.GPU)
else:
scratch_image_scale = scratch_image_scale.cpu()
with torch.no_grad():
P = torch.sigmoid(model(scratch_image_scale))
P = P.data.cpu()
P = F.interpolate(P, [ow, oh], mode="nearest")
scratch_mask = torch.clamp((P >= 0.4).float(), 0.0, 1.0) * 255
gc.collect()
torch.cuda.empty_cache()
return tensor2image(scratch_mask[0].byte()), transformed_image_PIL
def global_detection(
input_image: Image.Image,
gpu: int,
input_size: str,
) -> Image.Image:
config = argparse.Namespace()
config.GPU = gpu
config.input_size = input_size
return main(config, input_image)
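# Illustrative usage sketch (hypothetical file name; not part of the original file).
# mask, resized = global_detection(Image.open("old_photo.png").convert("RGB"),
#                                  gpu=-1, input_size="full_size")
# mask is the predicted scratch mask; resized is the 16-aligned input it was computed on.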

View File

@@ -1,70 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class Downsample(nn.Module):
# https://github.com/adobe/antialiased-cnns
def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
super(Downsample, self).__init__()
self.filt_size = filt_size
self.pad_off = pad_off
self.pad_sizes = [
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
int(1.0 * (filt_size - 1) / 2),
int(np.ceil(1.0 * (filt_size - 1) / 2)),
]
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
self.stride = stride
self.off = int((self.stride - 1) / 2.0)
self.channels = channels
# print('Filter size [%i]'%filt_size)
if self.filt_size == 1:
a = np.array([1.0,])
elif self.filt_size == 2:
a = np.array([1.0, 1.0])
elif self.filt_size == 3:
a = np.array([1.0, 2.0, 1.0])
elif self.filt_size == 4:
a = np.array([1.0, 3.0, 3.0, 1.0])
elif self.filt_size == 5:
a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
elif self.filt_size == 6:
a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
elif self.filt_size == 7:
a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
filt = torch.Tensor(a[:, None] * a[None, :])
filt = filt / torch.sum(filt)
self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
def forward(self, inp):
if self.filt_size == 1:
if self.pad_off == 0:
return inp[:, :, :: self.stride, :: self.stride]
else:
return self.pad(inp)[:, :, :: self.stride, :: self.stride]
else:
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
def get_pad_layer(pad_type):
if pad_type in ["refl", "reflect"]:
PadLayer = nn.ReflectionPad2d
elif pad_type in ["repl", "replicate"]:
PadLayer = nn.ReplicationPad2d
elif pad_type == "zero":
PadLayer = nn.ZeroPad2d
else:
print("Pad type [%s] not recognized" % pad_type)
return PadLayer
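# Illustrative usage sketch (added for clarity; not part of the original file).
# Downsample blurs with a small binomial filter before striding, the anti-aliased
# replacement for a plain stride-2 convolution or pooling step.
example_feat = torch.randn(1, 8, 32, 32)
example_down = Downsample(channels=8, stride=2)(example_feat)
assert example_down.shape == (1, 8, 16, 16)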

View File

@@ -1,332 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm import DataParallelWithCallback
from .antialiasing import Downsample
class UNet(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
depth=5,
conv_num=2,
wf=6,
padding=True,
batch_norm=True,
up_mode="upsample",
with_tanh=False,
sync_bn=True,
antialiasing=True,
):
"""
Implementation of
U-Net: Convolutional Networks for Biomedical Image Segmentation
(Ronneberger et al., 2015)
https://arxiv.org/abs/1505.04597
Using the default arguments will yield the exact version used
in the original paper
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
depth (int): depth of the network
wf (int): number of filters in the first layer is 2**wf
padding (bool): if True, apply padding such that the input shape
is the same as the output.
This may introduce artifacts
batch_norm (bool): Use BatchNorm after layers with an
activation function
up_mode (str): one of 'upconv' or 'upsample'.
'upconv' will use transposed convolutions for
learned upsampling.
'upsample' will use bilinear upsampling.
"""
super().__init__()
assert up_mode in ("upconv", "upsample")
self.padding = padding
self.depth = depth - 1
prev_channels = in_channels
self.first = nn.Sequential(
*[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
)
prev_channels = 2 ** wf
self.down_path = nn.ModuleList()
self.down_sample = nn.ModuleList()
for i in range(depth):
if antialiasing and depth > 0:
self.down_sample.append(
nn.Sequential(
*[
nn.ReflectionPad2d(1),
nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
nn.BatchNorm2d(prev_channels),
nn.LeakyReLU(0.2, True),
Downsample(channels=prev_channels, stride=2),
]
)
)
else:
self.down_sample.append(
nn.Sequential(
*[
nn.ReflectionPad2d(1),
nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
nn.BatchNorm2d(prev_channels),
nn.LeakyReLU(0.2, True),
]
)
)
self.down_path.append(
UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
)
prev_channels = 2 ** (wf + i + 1)
self.up_path = nn.ModuleList()
for i in reversed(range(depth)):
self.up_path.append(
UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
)
prev_channels = 2 ** (wf + i)
if with_tanh:
self.last = nn.Sequential(
*[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
)
else:
self.last = nn.Sequential(
*[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
)
if sync_bn:
# note: rebinding `self` inside __init__ does not affect the instance seen by the caller;
# wrap the constructed UNet with DataParallelWithCallback at the call site if sync BN is needed
self = DataParallelWithCallback(self)
def forward(self, x):
x = self.first(x)
blocks = []
for i, down_block in enumerate(self.down_path):
blocks.append(x)
x = self.down_sample[i](x)
x = down_block(x)
for i, up in enumerate(self.up_path):
x = up(x, blocks[-i - 1])
return self.last(x)
class UNetConvBlock(nn.Module):
def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
super(UNetConvBlock, self).__init__()
block = []
for _ in range(conv_num):
block.append(nn.ReflectionPad2d(padding=int(padding)))
block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0))
if batch_norm:
block.append(nn.BatchNorm2d(out_size))
block.append(nn.LeakyReLU(0.2, True))
in_size = out_size
self.block = nn.Sequential(*block)
def forward(self, x):
out = self.block(x)
return out
class UNetUpBlock(nn.Module):
def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
super(UNetUpBlock, self).__init__()
if up_mode == "upconv":
self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
elif up_mode == "upsample":
self.up = nn.Sequential(
nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
nn.ReflectionPad2d(1),
nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
)
self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)
def center_crop(self, layer, target_size):
_, _, layer_height, layer_width = layer.size()
diff_y = (layer_height - target_size[0]) // 2
diff_x = (layer_width - target_size[1]) // 2
return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]
def forward(self, x, bridge):
up = self.up(x)
crop1 = self.center_crop(bridge, up.shape[2:])
out = torch.cat([up, crop1], 1)
out = self.conv_block(out)
return out
class UnetGenerator(nn.Module):
"""Create a Unet-based generator"""
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
"""Construct a Unet generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
an image of size 128x128 will become of size 1x1 at the bottleneck
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
We construct the U-Net from the innermost layer to the outermost layer.
It is a recursive process.
"""
super().__init__()
if norm_type == "BN":
norm_layer = nn.BatchNorm2d
elif norm_type == "IN":
norm_layer = nn.InstanceNorm2d
else:
raise NameError("Unknown norm layer")
# construct unet structure
unet_block = UnetSkipConnectionBlock(
ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
) # add the innermost layer
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
unet_block = UnetSkipConnectionBlock(
ngf * 8,
ngf * 8,
input_nc=None,
submodule=unet_block,
norm_layer=norm_layer,
use_dropout=use_dropout,
)
# gradually reduce the number of filters from ngf * 8 to ngf
unet_block = UnetSkipConnectionBlock(
ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
)
unet_block = UnetSkipConnectionBlock(
ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
)
unet_block = UnetSkipConnectionBlock(
ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
)
self.model = UnetSkipConnectionBlock(
output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer
) # add the outermost layer
def forward(self, input):
return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
"""Defines the Unet submodule with skip connection.
-------------------identity----------------------
|-- downsampling -- |submodule| -- upsampling --|
"""
def __init__(
self,
outer_nc,
inner_nc,
input_nc=None,
submodule=None,
outermost=False,
innermost=False,
norm_layer=nn.BatchNorm2d,
use_dropout=False,
):
"""Construct a Unet submodule with skip connections.
Parameters:
outer_nc (int) -- the number of filters in the outer conv layer
inner_nc (int) -- the number of filters in the inner conv layer
input_nc (int) -- the number of channels in input images/features
submodule (UnetSkipConnectionBlock) -- previously defined submodules
outermost (bool) -- if this module is the outermost module
innermost (bool) -- if this module is the innermost module
norm_layer -- normalization layer
use_dropout (bool) -- whether to use dropout layers.
"""
super().__init__()
self.outermost = outermost
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.LeakyReLU(0.2, True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(
inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else: # add skip connections
return torch.cat([x, self.model(x)], 1)
# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
from torchsummary import summary
from torchviz import make_dot  # needed for the graph rendering below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# UNet_two_decoders is not defined in this module; the UNet defined above is used
# here instead so the summary test can actually run
model = UNet(
in_channels=3,
out_channels=3,
depth=4,
conv_num=1,
wf=6,
padding=True,
batch_norm=True,
up_mode="upsample",
with_tanh=False,
)
model.to(device)
model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
model_pix2pix.to(device)
print("customized unet:")
summary(model, (3, 256, 256))
print("cyclegan unet:")
summary(model_pix2pix, (3, 256, 256))
x = torch.zeros(1, 3, 256, 256, device=device, requires_grad=True)
g = make_dot(model(x))
g.render("models/Digraph.gv", view=False)

View File

@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import set_sbn_eps_mode
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback

View File

@@ -1,412 +0,0 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import contextlib
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
try:
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
ReduceAddCoalesced = Broadcast = None
try:
from jactorch.parallel.comm import SyncMaster
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
from .comm import SyncMaster
from .replicate import DataParallelWithCallback
__all__ = [
'set_sbn_eps_mode',
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
'patch_sync_batchnorm', 'convert_model'
]
SBN_EPS_MODE = 'clamp'
def set_sbn_eps_mode(mode):
global SBN_EPS_MODE
assert mode in ('clamp', 'plus')
SBN_EPS_MODE = mode
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dimensions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
track_running_stats=track_running_stats)
if not self.track_running_stats:
import warnings
warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
if hasattr(torch, 'no_grad'):
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
else:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
if SBN_EPS_MODE == 'clamp':
return mean, bias_var.clamp(self.eps) ** -0.5
elif SBN_EPS_MODE == 'plus':
return mean, (bias_var + self.eps) ** -0.5
else:
raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
@contextlib.contextmanager
def patch_sync_batchnorm():
import torch.nn as nn
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
nn.BatchNorm1d = SynchronizedBatchNorm1d
nn.BatchNorm2d = SynchronizedBatchNorm2d
nn.BatchNorm3d = SynchronizedBatchNorm3d
yield
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
def convert_model(module):
"""Traverse the input module and its child recursively
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
to SynchronizedBatchNorm*N*d
Args:
module: the input module needs to be convert to SyncBN model
Examples:
>>> import torch.nn as nn
>>> import torchvision
>>> # m is a standard pytorch model
>>> m = torchvision.models.resnet18(True)
>>> m = nn.DataParallel(m)
>>> # after convert, m is using SyncBN
>>> m = convert_model(m)
"""
if isinstance(module, torch.nn.DataParallel):
mod = module.module
mod = convert_model(mod)
mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
return mod
mod = module
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.BatchNorm3d],
[SynchronizedBatchNorm1d,
SynchronizedBatchNorm2d,
SynchronizedBatchNorm3d]):
if isinstance(module, pth_module):
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_model(child))
return mod

View File

@@ -1,74 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batchnorm_reimpl.py
# Author : acgtyrant
# Date : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import torch
import torch.nn as nn
import torch.nn.init as init
__all__ = ['BatchNorm2dReimpl']
class BatchNorm2dReimpl(nn.Module):
"""
A re-implementation of batch normalization, used for testing the numerical
stability.
Author: acgtyrant
See also:
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.weight = nn.Parameter(torch.empty(num_features))
self.bias = nn.Parameter(torch.empty(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
def reset_parameters(self):
self.reset_running_stats()
init.uniform_(self.weight)
init.zeros_(self.bias)
def forward(self, input_):
batchsize, channels, height, width = input_.size()
numel = batchsize * height * width
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
sum_ = input_.sum(1)
sum_of_square = input_.pow(2).sum(1)
mean = sum_ / numel
sumvar = sum_of_square - sum_ * mean
self.running_mean = (
(1 - self.momentum) * self.running_mean
+ self.momentum * mean.detach()
)
unbias_var = sumvar / (numel - 1)
self.running_var = (
(1 - self.momentum) * self.running_var
+ self.momentum * unbias_var.detach()
)
bias_var = sumvar / numel
inv_std = 1 / (bias_var + self.eps).pow(0.5)
output = (
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
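# Illustrative check sketch (added for clarity; not part of the original file).
# In training mode the reimplementation should match nn.BatchNorm2d closely once
# the affine parameters are copied over.
example_x = torch.randn(4, 3, 8, 8)
example_ref = nn.BatchNorm2d(3)
example_reimpl = BatchNorm2dReimpl(3)
example_reimpl.weight.data.copy_(example_ref.weight.data)
example_reimpl.bias.data.copy_(example_ref.bias.data)
assert torch.allclose(example_ref(example_x), example_reimpl(example_x), atol=1e-5)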

View File

@@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, "Previous result hasn't been fetched."
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During the replication, as the data parallel will trigger a callback on each module, all slave devices should
call `register(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
and passed to a registered callback.
- After receiving the messages, the master device should gather the information and determine the message to be passed
back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually is the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master want to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)

View File

@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all modules are isomorphic, we assign each sub-module a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` on each module will be invoked after the replica is created by
the original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate

View File

@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import torch
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, x, y):
adiff = float((x - y).abs().max())
if (y == 0).all():
rdiff = 'NaN'
else:
rdiff = float((adiff / y).abs().max())
message = (
'Tensor close check failed\n'
'adiff={}\n'
'rdiff={}\n'
).format(adiff, rdiff)
self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)

View File

@@ -1,245 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import time
import shutil
import platform
import numpy as np
from datetime import datetime
import torch
import torchvision as tv
import torch.backends.cudnn as cudnn
# from torch.utils.tensorboard import SummaryWriter  # NOTE: prepare_tensorboard below requires this import
import yaml
import matplotlib.pyplot as plt
from easydict import EasyDict as edict
import torchvision.utils as vutils
##### option parsing ######
def print_options(config_dict):
print("------------ Options -------------")
for k, v in sorted(config_dict.items()):
print("%s: %s" % (str(k), str(v)))
print("-------------- End ----------------")
def save_options(config_dict):
from time import gmtime, strftime
file_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"])
mkdir_if_not(file_dir)
file_name = os.path.join(file_dir, "opt.txt")
with open(file_name, "wt") as opt_file:
opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
opt_file.write("------------ Options -------------\n")
for k, v in sorted(config_dict.items()):
opt_file.write("%s: %s\n" % (str(k), str(v)))
opt_file.write("-------------- End ----------------\n")
def config_parse(config_file, options, save=True):
with open(config_file, "r") as stream:
config_dict = yaml.safe_load(stream)
config = edict(config_dict)
for option_key, option_value in vars(options).items():
config_dict[option_key] = option_value
config[option_key] = option_value
if config.debug_mode:
config_dict["num_workers"] = 0
config.num_workers = 0
config.batch_size = 2
if isinstance(config.gpu_ids, str):
config.gpu_ids = [int(x) for x in config.gpu_ids.split(",")][0]
print_options(config_dict)
if save:
save_options(config_dict)
return config
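# Illustrative usage sketch (hypothetical config file and CLI options; not part of the original file).
# options = argparse.ArgumentParser().parse_args([])  # any Namespace carrying the extra fields
# options.debug_mode = False
# options.gpu_ids = "0"
# config = config_parse("configs/scratch_detection.yaml", options, save=False)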
###### utility ######
def to_np(x):
return x.cpu().numpy()
def prepare_device(use_gpu, gpu_ids):
if torch.cuda.is_available() and use_gpu:
cudnn.benchmark = True
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if isinstance(gpu_ids, str):
gpu_ids = [int(x) for x in gpu_ids.split(",")]
torch.cuda.set_device(gpu_ids[0])
device = torch.device("cuda:" + str(gpu_ids[0]))
else:
torch.cuda.set_device(gpu_ids)
device = torch.device("cuda:" + str(gpu_ids))
print("running on GPU {}".format(gpu_ids))
else:
device = torch.device("cpu")
print("running on CPU")
return device
###### file system ######
def get_dir_size(start_path="."):
total_size = 0
for dirpath, dirnames, filenames in os.walk(start_path):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
return total_size
def mkdir_if_not(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
##### System related ######
class Timer:
def __init__(self, msg):
self.msg = msg
self.start_time = None
def __enter__(self):
self.start_time = time.time()
def __exit__(self, exc_type, exc_value, exc_tb):
elapse = time.time() - self.start_time
print(self.msg % elapse)
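# Illustrative usage sketch (added for clarity; not part of the original file).
# Timer is a context manager; the message must contain one %-style float placeholder.
with Timer("example block took %.3f s"):
    time.sleep(0.01)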
###### interactive ######
def get_size(start_path="."):
total_size = 0
for dirpath, dirnames, filenames in os.walk(start_path):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
return total_size
def clean_tensorboard(directory):
tensorboard_list = os.listdir(directory)
SIZE_THRESH = 100000
for tensorboard in tensorboard_list:
tensorboard = os.path.join(directory, tensorboard)
if get_size(tensorboard) < SIZE_THRESH:
print("deleting the empty tensorboard: ", tensorboard)
#
if os.path.isdir(tensorboard):
shutil.rmtree(tensorboard)
else:
os.remove(tensorboard)
def prepare_tensorboard(config, experiment_name=datetime.now().strftime("%Y-%m-%d %H-%M-%S")):
tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs")
mkdir_if_not(tensorboard_directory)
clean_tensorboard(tensorboard_directory)
tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10)
# try:
# shutil.copy('outputs/opt.txt', tensorboard_directory)
# except:
# print('cannot find file opt.txt')
return tb_writer
def tb_loss_logger(tb_writer, iter_index, loss_logger):
for tag, value in loss_logger.items():
tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index)
def tb_image_logger(tb_writer, iter_index, images_info, config):
### Save and write the output into the tensorboard
tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode)
mkdir_if_not(tb_logger_path)
for tag, image in images_info.items():
if tag == "test_image_prediction" or tag == "image_prediction":
continue
image = tv.utils.make_grid(image.cpu())
image = torch.clamp(image, 0, 1)
tb_writer.add_image(tag, img_tensor=image, global_step=iter_index)
tv.transforms.functional.to_pil_image(image).save(
os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag))
)
def tb_image_logger_test(epoch, iter, images_info, config):
url = os.path.join(config.output_dir, config.name, config.train_mode, "val_" + str(epoch))
if not os.path.exists(url):
os.makedirs(url)
scratch_img = images_info["test_scratch_image"].data.cpu()
if config.norm_input:
scratch_img = (scratch_img + 1.0) / 2.0
scratch_img = torch.clamp(scratch_img, 0, 1)
gt_mask = images_info["test_mask_image"].data.cpu()
predict_mask = images_info["test_scratch_prediction"].data.cpu()
predict_hard_mask = (predict_mask.data.cpu() >= 0.5).float()
imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0)
vutils.save_image(
imgs, os.path.join(url, str(iter) + ".jpg"), nrow=len(scratch_img), padding=0, normalize=True
)
def imshow(input_image, title=None, to_numpy=False):
inp = input_image
if to_numpy or type(input_image) is torch.Tensor:
inp = input_image.numpy()
fig = plt.figure()
if inp.ndim == 2:
fig = plt.imshow(inp, cmap="gray", clim=[0, 255])
else:
fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8))
plt.axis("off")
fig.axes.get_xaxis().set_visible(False)
fig.axes.get_yaxis().set_visible(False)
plt.title(title)
###### vgg preprocessing ######
def vgg_preprocess(tensor):
# input is RGB tensor which ranges in [0,1]
# output is BGR tensor which ranges in [0,255]
tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
# tensor_bgr = tensor[:, [2, 1, 0], ...]
tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(
1, 3, 1, 1
)
tensor_rst = tensor_bgr_ml * 255
return tensor_rst
def torch_vgg_preprocess(tensor):
# pytorch version normalization
# note that both input and output are RGB tensors;
# input and output ranges in [0,1]
# normalize the tensor with mean and variance
tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)
tensor_mc_norm = tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1)
return tensor_mc_norm
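# A minimal usage sketch of the two preprocessing variants above (assumes torch is
# already imported in this module and x is an RGB batch in [0, 1]):
x = torch.rand(1, 3, 224, 224)
caffe_input = vgg_preprocess(x)        # Caffe-style: BGR order, mean subtracted, scaled by 255
torch_input = torch_vgg_preprocess(x)  # torchvision-style: RGB, ImageNet mean/std normalisation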
def network_gradient(net, gradient_on=True):
if gradient_on:
for param in net.parameters():
param.requires_grad = True
else:
for param in net.parameters():
param.requires_grad = False
return net

View File

@@ -1,204 +0,0 @@
# Copyright (c) Microsoft Corporation
import torch.nn as nn
from . import networks
class Mapping_Model_with_mask(nn.Module):
def __init__(
self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
):
super(Mapping_Model_with_mask, self).__init__()
norm_layer = networks.get_norm_layer(norm_type=norm)
activation = nn.ReLU(True)
model = []
tmp_nc = 64
n_up = 4
for i in range(n_up):
ic = min(tmp_nc * (2**i), mc)
oc = min(tmp_nc * (2 ** (i + 1)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
self.before_NL = nn.Sequential(*model)
if opt.NL_res:
self.NL = networks.NonLocalBlock2D_with_mask_Res(
mc,
mc,
opt.NL_fusion_method,
opt.correlation_renormalize,
opt.softmax_temperature,
opt.use_self,
opt.cosin_similarity,
)
print("using NL + Res...")
model = []
for i in range(n_blocks):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
for i in range(n_up - 1):
ic = min(64 * (2 ** (4 - i)), mc)
oc = min(64 * (2 ** (3 - i)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
if opt.feat_dim > 0 and opt.feat_dim < 64:
model += [
norm_layer(tmp_nc),
activation,
nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
]
# model += [nn.Conv2d(64, 1, 1, 1, 0)]
self.after_NL = nn.Sequential(*model)
def forward(self, input, mask):
x1 = self.before_NL(input)
del input
x2 = self.NL(x1, mask)
del x1, mask
x3 = self.after_NL(x2)
del x2
return x3
class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention
def __init__(
self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
):
super(Mapping_Model_with_mask_2, self).__init__()
norm_layer = networks.get_norm_layer(norm_type=norm)
activation = nn.ReLU(True)
model = []
tmp_nc = 64
n_up = 4
for i in range(n_up):
ic = min(tmp_nc * (2**i), mc)
oc = min(tmp_nc * (2 ** (i + 1)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
print("using multi-scale patch attention, conv combine + mask input...")
self.before_NL = nn.Sequential(*model)
if opt.mapping_exp == 1:
self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8)
model = []
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
self.res_block_1 = nn.Sequential(*model)
if opt.mapping_exp == 1:
self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4)
model = []
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
self.res_block_2 = nn.Sequential(*model)
if opt.mapping_exp == 1:
self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2)
# self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2)
model = []
for i in range(2):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
for i in range(n_up - 1):
ic = min(64 * (2 ** (4 - i)), mc)
oc = min(64 * (2 ** (3 - i)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
if opt.feat_dim > 0 and opt.feat_dim < 64:
model += [
norm_layer(tmp_nc),
activation,
nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
]
# model += [nn.Conv2d(64, 1, 1, 1, 0)]
self.after_NL = nn.Sequential(*model)
def forward(self, input, mask):
x1 = self.before_NL(input)
x2 = self.NL_scale_1(x1, mask)
x3 = self.res_block_1(x2)
x4 = self.NL_scale_2(x3, mask)
x5 = self.res_block_2(x4)
x6 = self.NL_scale_3(x5, mask)
x7 = self.after_NL(x6)
return x7
def inference_forward(self, input, mask):
x1 = self.before_NL(input)
del input
x2 = self.NL_scale_1.inference_forward(x1, mask)
del x1
x3 = self.res_block_1(x2)
del x2
x4 = self.NL_scale_2.inference_forward(x3, mask)
del x3
x5 = self.res_block_2(x4)
del x4
x6 = self.NL_scale_3.inference_forward(x5, mask)
del x5
x7 = self.after_NL(x6)
del x6
return x7

View File

@@ -1,122 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import torch
import sys
class BaseModel(torch.nn.Module):
def name(self):
return "BaseModel"
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if (torch.cuda.is_available() and self.gpu_ids) else torch.Tensor
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
def set_input(self, input):
self.input = input
def forward(self):
pass
# used at test time, no backprop
def test(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda()
def save_optimizer(self, optimizer, optimizer_label, epoch_label):
save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(optimizer.state_dict(), save_path)
def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""):
save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
if not save_dir:
save_dir = self.save_dir
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
print("%s not exists yet!" % save_path)
else:
optimizer.load_state_dict(torch.load(save_path))
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label, save_dir=""):
save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
if not save_dir:
save_dir = self.save_dir
# print(save_dir)
# print(self.save_dir)
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
print("%s not exists yet!" % save_path)
# if network_label == 'G':
# raise('Generator must exist!')
else:
# network.load_state_dict(torch.load(save_path))
try:
# print(save_path)
network.load_state_dict(torch.load(save_path))
except:
pretrained_dict = torch.load(save_path)
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
# if self.opt.verbose:
print(
"Pretrained network %s has excessive layers; Only loading layers that are used"
% network_label
)
except:
print(
"Pretrained network %s has fewer layers; The following are not initialized:"
% network_label
)
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
if sys.version_info >= (3, 0):
not_initialized = set()
else:
from sets import Set
not_initialized = Set()
for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
not_initialized.add(k.split(".")[0])
print(sorted(not_initialized))
network.load_state_dict(model_dict)
def update_learning_rate(self):
pass
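# For orientation, a short comment-only sketch of the checkpoint naming convention
# used by the helpers above (the arguments shown are hypothetical):
#   save_network(netG, "G", "latest", gpu_ids=[0])  -> <checkpoints_dir>/<name>/latest_net_G.pth
#   save_optimizer(optimizer_G, "G", "latest")      -> <checkpoints_dir>/<name>/latest_optimizer_G.pth
#   load_network(netG, "G", "latest") restores the weights, falling back to a partial
#   load when the stored and current state dicts do not match exactly.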

View File

@@ -1,453 +0,0 @@
# Copyright (c) Microsoft Corporation
from .NonLocal_feature_mapping_model import *
from ..util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from torch.autograd import Variable
import torch.nn as nn
import torch
class Mapping_Model(nn.Module):
def __init__(
self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
):
super(Mapping_Model, self).__init__()
norm_layer = networks.get_norm_layer(norm_type=norm)
activation = nn.ReLU(True)
model = []
tmp_nc = 64
n_up = 4
# print("Mapping: You are using the mapping model without global restoration.")
for i in range(n_up):
ic = min(tmp_nc * (2**i), mc)
oc = min(tmp_nc * (2 ** (i + 1)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
for i in range(n_blocks):
model += [
networks.ResnetBlock(
mc,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
dilation=opt.mapping_net_dilation,
)
]
for i in range(n_up - 1):
ic = min(64 * (2 ** (4 - i)), mc)
oc = min(64 * (2 ** (3 - i)), mc)
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
if opt.feat_dim > 0 and opt.feat_dim < 64:
model += [
norm_layer(tmp_nc),
activation,
nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
]
# model += [nn.Conv2d(64, 1, 1, 1, 0)]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class Pix2PixHDModel_Mapping(BaseModel):
def name(self):
return "Pix2PixHDModel_Mapping"
def init_loss_filter(
self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2
):
flags = (
True,
True,
use_gan_feat_loss,
use_vgg_loss,
True,
True,
use_smooth_l1,
stage_1_feat_l2,
)
def loss_filter(
g_feat_l2,
g_gan,
g_gan_feat,
g_vgg,
d_real,
d_fake,
smooth_l1,
stage_1_feat_l2,
):
return [
l
for (l, f) in zip(
(
g_feat_l2,
g_gan,
g_gan_feat,
g_vgg,
d_real,
d_fake,
smooth_l1,
stage_1_feat_l2,
),
flags,
)
if f
]
return loss_filter
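# A standalone sketch of the flag-filter idiom above: each flag decides whether the
# matching loss entry is kept, so disabled loss terms simply drop out of the list.
flags = (True, True, False, True)
losses = ("g_gan", "g_gan_feat", "g_vgg", "d_real")
kept = [l for (l, f) in zip(losses, flags) if f]   # -> ["g_gan", "g_gan_feat", "d_real"]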
def initialize(self, opt):
BaseModel.initialize(self, opt)
if opt.resize_or_crop != "none" or not opt.isTrain:
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
##### define networks
# Generator network
netG_input_nc = input_nc
self.netG_A = networks.GlobalGenerator_DCDCv2(
netG_input_nc,
opt.output_nc,
opt.ngf,
opt.k_size,
opt.n_downsample_global,
networks.get_norm_layer(norm_type=opt.norm),
opt=opt,
)
self.netG_B = networks.GlobalGenerator_DCDCv2(
netG_input_nc,
opt.output_nc,
opt.ngf,
opt.k_size,
opt.n_downsample_global,
networks.get_norm_layer(norm_type=opt.norm),
opt=opt,
)
if opt.non_local == "Setting_42" or opt.NL_use_mask:
if opt.mapping_exp == 1:
self.mapping_net = Mapping_Model_with_mask_2(
min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
opt.map_mc,
n_blocks=opt.mapping_n_block,
opt=opt,
)
else:
self.mapping_net = Mapping_Model_with_mask(
min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
opt.map_mc,
n_blocks=opt.mapping_n_block,
opt=opt,
)
else:
self.mapping_net = Mapping_Model(
min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
opt.map_mc,
n_blocks=opt.mapping_n_block,
opt=opt,
)
self.mapping_net.apply(networks.weights_init)
if opt.load_pretrain != "":
self.load_network(
self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain
)
if not opt.no_load_VAE:
self.load_network(
self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA
)
self.load_network(
self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB
)
for param in self.netG_A.parameters():
param.requires_grad = False
for param in self.netG_B.parameters():
param.requires_grad = False
self.netG_A.eval()
self.netG_B.eval()
if torch.cuda.is_available() and opt.gpu_ids:
self.netG_A.cuda(opt.gpu_ids[0])
self.netG_B.cuda(opt.gpu_ids[0])
self.mapping_net.cuda(opt.gpu_ids[0])
if not self.isTrain:
self.load_network(self.mapping_net, "mapping_net", opt.which_epoch)
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc
if not opt.no_instance:
netD_input_nc += 1
self.netD = networks.define_D(
netD_input_nc,
opt.ndf,
opt.n_layers_D,
opt,
opt.norm,
use_sigmoid,
opt.num_D,
not opt.no_ganFeat_loss,
gpu_ids=self.gpu_ids,
)
# set loss functions and optimizers
if self.isTrain:
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
self.fake_pool = ImagePool(opt.pool_size)
self.old_lr = opt.lr
# define loss functions
self.loss_filter = self.init_loss_filter(
not opt.no_ganFeat_loss,
not opt.no_vgg_loss,
opt.Smooth_L1,
opt.use_two_stage_mapping,
)
self.criterionGAN = networks.GANLoss(
use_lsgan=not opt.no_lsgan, tensor=self.Tensor
)
self.criterionFeat = torch.nn.L1Loss()
self.criterionFeat_feat = (
torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss()
)
if self.opt.image_L1:
self.criterionImage = torch.nn.L1Loss()
else:
self.criterionImage = torch.nn.SmoothL1Loss()
print(self.criterionFeat_feat)
if not opt.no_vgg_loss:
self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)
# Names so we can break out individual loss terms
self.loss_names = self.loss_filter(
"G_Feat_L2",
"G_GAN",
"G_GAN_Feat",
"G_VGG",
"D_real",
"D_fake",
"Smooth_L1",
"G_Feat_L2_Stage_1",
)
# initialize optimizers
# optimizer G
if opt.no_TTUR:
beta1, beta2 = opt.beta1, 0.999
G_lr, D_lr = opt.lr, opt.lr
else:
beta1, beta2 = 0, 0.9
G_lr, D_lr = opt.lr / 2, opt.lr * 2
if not opt.no_load_VAE:
params = list(self.mapping_net.parameters())
self.optimizer_mapping = torch.optim.Adam(
params, lr=G_lr, betas=(beta1, beta2)
)
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2))
print("---------- Optimizers initialized -------------")
def encode_input(
self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False
):
if self.opt.label_nc == 0:
input_label = label_map.data.cuda()
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)
# real images for training
if real_image is not None:
real_image = Variable(real_image.data.cuda())
return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
if use_pool:
fake_query = self.fake_pool.query(input_concat)
return self.netD.forward(fake_query)
else:
return self.netD.forward(input_concat)
def forward(
self,
label,
inst,
image,
feat,
pair=True,
infer=False,
last_label=None,
last_image=None,
):
# Encode Inputs
input_label, inst_map, real_image, feat_map = self.encode_input(
label, inst, image, feat
)
# Fake Generation
input_concat = input_label
label_feat = self.netG_A.forward(input_concat, flow="enc")
# print('label:')
# print(label_feat.min(), label_feat.max(), label_feat.mean())
# label_feat = label_feat / 16.0
if self.opt.NL_use_mask:
label_feat_map = self.mapping_net(label_feat.detach(), inst)
else:
label_feat_map = self.mapping_net(label_feat.detach())
fake_image = self.netG_B.forward(label_feat_map, flow="dec")
image_feat = self.netG_B.forward(real_image, flow="enc")
loss_feat_l2_stage_1 = 0
loss_feat_l2 = (
self.criterionFeat_feat(label_feat_map, image_feat.data) * self.opt.l2_feat
)
if self.opt.feat_gan:
# Fake Detection and Loss
pred_fake_pool = self.discriminate(
label_feat.detach(), label_feat_map, use_pool=True
)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(label_feat.detach(), image_feat)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(
torch.cat((label_feat.detach(), label_feat_map), dim=1)
)
loss_G_GAN = self.criterionGAN(pred_fake, True)
else:
# Fake Detection and Loss
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
if pair:
pred_real = self.discriminate(input_label, real_image)
else:
pred_real = self.discriminate(last_label, last_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
# GAN feature matching loss
loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss and pair:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i]) - 1):
tmp = (
self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
* self.opt.lambda_feat
)
loss_G_GAN_Feat += D_weights * feat_weights * tmp
else:
loss_G_GAN_Feat = torch.zeros(1).to(label.device)
# VGG feature matching loss
loss_G_VGG = 0
if not self.opt.no_vgg_loss:
loss_G_VGG = (
self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
if pair
else torch.zeros(1).to(label.device)
)
smooth_l1_loss = 0
if self.opt.Smooth_L1:
smooth_l1_loss = (
self.criterionImage(fake_image, real_image) * self.opt.L1_weight
)
return [
self.loss_filter(
loss_feat_l2,
loss_G_GAN,
loss_G_GAN_Feat,
loss_G_VGG,
loss_D_real,
loss_D_fake,
smooth_l1_loss,
loss_feat_l2_stage_1,
),
None if not infer else fake_image,
]
def inference(self, label, inst):
use_gpu = torch.cuda.is_available() and len(self.opt.gpu_ids) > 0
if use_gpu:
input_concat = label.data.cuda()
inst_data = inst.cuda()
else:
input_concat = label.data
inst_data = inst
label_feat = self.netG_A.forward(input_concat, flow="enc")
if self.opt.NL_use_mask:
if self.opt.inference_optimize:
label_feat_map = self.mapping_net.inference_forward(
label_feat.detach(), inst_data
)
else:
label_feat_map = self.mapping_net(label_feat.detach(), inst_data)
else:
label_feat_map = self.mapping_net(label_feat.detach())
fake_image = self.netG_B.forward(label_feat_map, flow="dec")
return fake_image
class InferenceModel(Pix2PixHDModel_Mapping):
def forward(self, label, inst):
return self.inference(label, inst)

View File

@@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation
def create_model(opt):
assert opt.model == "pix2pixHD"
from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
if opt.isTrain:
model = Pix2PixHDModel()
else:
model = InferenceModel()
model.initialize(opt)
if opt.verbose:
print("model [%s] was created" % (model.name()))
assert not opt.isTrain
return model
def create_da_model(opt):
assert opt.model == "pix2pixHD"
from .pix2pixHD_model_DA import Pix2PixHDModel, InferenceModel
if opt.isTrain:
model = Pix2PixHDModel()
else:
model = InferenceModel()
model.initialize(opt)
if opt.verbose:
print("model [%s] was created" % (model.name()))
assert not opt.isTrain
return model

View File

@@ -1,875 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
from torch.nn.utils import spectral_norm
# from util.util import SwitchNorm2d
import torch.nn.functional as F
###############################################################################
# Functions
###############################################################################
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_norm_layer(norm_type="instance"):
if norm_type == "batch":
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == "instance":
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type == "spectral":
norm_layer = spectral_norm()
elif norm_type == "SwitchNorm":
norm_layer = SwitchNorm2d
else:
raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
return norm_layer
def print_network(net):
if isinstance(net, list):
net = net[0]
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print("Total number of parameters: %d" % num_params)
def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
n_blocks_local=3, norm='instance', gpu_ids=[], opt=None):
norm_layer = get_norm_layer(norm_type=norm)
if netG == 'global':
# if opt.self_gen:
if opt.use_v2:
netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt)
else:
netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt)
else:
raise NotImplementedError('generator not implemented!')
print(netG)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netG.cuda(gpu_ids[0])
netG.apply(weights_init)
return netG
def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
print(netD)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netD.cuda(gpu_ids[0])
netD.apply(weights_init)
return netD
class GlobalGenerator_DCDCv2(nn.Module):
def __init__(
self,
input_nc,
output_nc,
ngf=64,
k_size=3,
n_downsampling=8,
norm_layer=nn.BatchNorm2d,
padding_type="reflect",
opt=None,
):
super(GlobalGenerator_DCDCv2, self).__init__()
activation = nn.ReLU(True)
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0),
norm_layer(ngf),
activation,
]
### downsample
for i in range(opt.start_r):
mult = 2 ** i
model += [
nn.Conv2d(
min(ngf * mult, opt.mc),
min(ngf * mult * 2, opt.mc),
kernel_size=k_size,
stride=2,
padding=1,
),
norm_layer(min(ngf * mult * 2, opt.mc)),
activation,
]
for i in range(opt.start_r, n_downsampling - 1):
mult = 2 ** i
model += [
nn.Conv2d(
min(ngf * mult, opt.mc),
min(ngf * mult * 2, opt.mc),
kernel_size=k_size,
stride=2,
padding=1,
),
norm_layer(min(ngf * mult * 2, opt.mc)),
activation,
]
model += [
ResnetBlock(
min(ngf * mult * 2, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
model += [
ResnetBlock(
min(ngf * mult * 2, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
mult = 2 ** (n_downsampling - 1)
if opt.spatio_size == 32:
model += [
nn.Conv2d(
min(ngf * mult, opt.mc),
min(ngf * mult * 2, opt.mc),
kernel_size=k_size,
stride=2,
padding=1,
),
norm_layer(min(ngf * mult * 2, opt.mc)),
activation,
]
if opt.spatio_size == 64:
model += [
ResnetBlock(
min(ngf * mult * 2, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
model += [
ResnetBlock(
min(ngf * mult * 2, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
# model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)]
if opt.feat_dim > 0:
model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)]
self.encoder = nn.Sequential(*model)
# decode
model = []
if opt.feat_dim > 0:
model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)]
# model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)]
o_pad = 0 if k_size == 4 else 1
mult = 2 ** n_downsampling
model += [
ResnetBlock(
min(ngf * mult, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
if opt.spatio_size == 32:
model += [
nn.ConvTranspose2d(
min(ngf * mult, opt.mc),
min(int(ngf * mult / 2), opt.mc),
kernel_size=k_size,
stride=2,
padding=1,
output_padding=o_pad,
),
norm_layer(min(int(ngf * mult / 2), opt.mc)),
activation,
]
if opt.spatio_size == 64:
model += [
ResnetBlock(
min(ngf * mult, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
for i in range(1, n_downsampling - opt.start_r):
mult = 2 ** (n_downsampling - i)
model += [
ResnetBlock(
min(ngf * mult, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
model += [
ResnetBlock(
min(ngf * mult, opt.mc),
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer,
opt=opt,
)
]
model += [
nn.ConvTranspose2d(
min(ngf * mult, opt.mc),
min(int(ngf * mult / 2), opt.mc),
kernel_size=k_size,
stride=2,
padding=1,
output_padding=o_pad,
),
norm_layer(min(int(ngf * mult / 2), opt.mc)),
activation,
]
for i in range(n_downsampling - opt.start_r, n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [
nn.ConvTranspose2d(
min(ngf * mult, opt.mc),
min(int(ngf * mult / 2), opt.mc),
kernel_size=k_size,
stride=2,
padding=1,
output_padding=o_pad,
),
norm_layer(min(int(ngf * mult / 2), opt.mc)),
activation,
]
if opt.use_segmentation_model:
model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)]
else:
model += [
nn.ReflectionPad2d(3),
nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0),
nn.Tanh(),
]
self.decoder = nn.Sequential(*model)
def forward(self, input, flow="enc_dec"):
if flow == "enc":
return self.encoder(input)
elif flow == "dec":
return self.decoder(input)
elif flow == "enc_dec":
x = self.encoder(input)
x = self.decoder(x)
return x
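# A minimal construction sketch; the real option parser fills in many more fields,
# so this SimpleNamespace is only a hypothetical stand-in:
from types import SimpleNamespace
opt = SimpleNamespace(mc=64, start_r=1, spatio_size=64, feat_dim=-1, use_segmentation_model=False)
netG = GlobalGenerator_DCDCv2(3, 3, ngf=64, k_size=3, n_downsampling=3, opt=opt)
feat = netG(torch.rand(1, 3, 256, 256), flow="enc")   # 1 x 64 x 64 x 64 latent feature map
img = netG(feat, flow="dec")                          # 1 x 3 x 256 x 256 output in [-1, 1]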
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(
self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
):
super(ResnetBlock, self).__init__()
self.opt = opt
self.dilation = dilation
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
conv_block = []
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(self.dilation)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(self.dilation)]
elif padding_type == "zero":
p = self.dilation
else:
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation),
norm_layer(dim),
activation,
]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class Encoder(nn.Module):
def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
super(Encoder, self).__init__()
self.output_nc = output_nc
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
nn.ReLU(True),
]
### downsample
for i in range(n_downsampling):
mult = 2 ** i
model += [
nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2),
nn.ReLU(True),
]
### upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [
nn.ConvTranspose2d(
ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1
),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True),
]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input, inst):
outputs = self.model(input)
# instance-wise average pooling
outputs_mean = outputs.clone()
inst_list = np.unique(inst.cpu().numpy().astype(int))
for i in inst_list:
for b in range(input.size()[0]):
indices = (inst[b : b + 1] == int(i)).nonzero() # n x 4
for j in range(self.output_nc):
output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]
mean_feat = torch.mean(output_ins).expand_as(output_ins)
outputs_mean[
indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]
] = mean_feat
return outputs_mean
def SN(module, mode=True):
if mode:
return torch.nn.utils.spectral_norm(module)
return module
class NonLocalBlock2D_with_mask_Res(nn.Module):
def __init__(
self,
in_channels,
inter_channels,
mode="add",
re_norm=False,
temperature=1.0,
use_self=False,
cosin=False,
):
super(NonLocalBlock2D_with_mask_Res, self).__init__()
self.cosin = cosin
self.renorm = re_norm
self.in_channels = in_channels
self.inter_channels = inter_channels
self.g = nn.Conv2d(
in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
)
self.W = nn.Conv2d(
in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
)
# for pytorch 0.3.1
# nn.init.constant(self.W.weight, 0)
# nn.init.constant(self.W.bias, 0)
# for pytorch 0.4.0
nn.init.constant_(self.W.weight, 0)
nn.init.constant_(self.W.bias, 0)
self.theta = nn.Conv2d(
in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
)
self.phi = nn.Conv2d(
in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
)
self.mode = mode
self.temperature = temperature
self.use_self = use_self
norm_layer = get_norm_layer(norm_type="instance")
activation = nn.ReLU(True)
model = []
for i in range(3):
model += [
ResnetBlock(
inter_channels,
padding_type="reflect",
activation=activation,
norm_layer=norm_layer,
opt=None,
)
]
self.res_block = nn.Sequential(*model)
def forward(self, x, mask): ## The shape of mask is Batch*1*H*W
batch_size = x.size(0)
g_x = self.g(x).view(batch_size, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
if self.cosin:
theta_x = F.normalize(theta_x, dim=2)
phi_x = F.normalize(phi_x, dim=1)
f = torch.matmul(theta_x, phi_x)
f /= self.temperature
f_div_C = F.softmax(f, dim=2)
tmp = 1 - mask
mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
mask[mask > 0] = 1.0
mask = 1 - mask
tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
mask *= tmp
mask_expand = mask.view(batch_size, 1, -1)
mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)
# mask = 1 - mask
# mask=F.interpolate(mask,(x.size(2),x.size(3)))
# mask_expand=mask.view(batch_size,1,-1)
# mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1)
if self.use_self:
mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0
# print(mask_expand.shape)
# print(f_div_C.shape)
f_div_C = mask_expand * f_div_C
if self.renorm:
f_div_C = F.normalize(f_div_C, p=1, dim=2)
###########################
y = torch.matmul(f_div_C, g_x)
y = y.permute(0, 2, 1).contiguous()
y = y.view(batch_size, self.inter_channels, *x.size()[2:])
W_y = self.W(y)
W_y = self.res_block(W_y)
if self.mode == "combine":
full_mask = mask.repeat(1, self.inter_channels, 1, 1)
z = full_mask * x + (1 - full_mask) * W_y
return z
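# A shape sketch with hypothetical sizes; mode="combine" pastes the attended features
# back into the damaged regions while keeping the clean regions untouched:
block = NonLocalBlock2D_with_mask_Res(64, 64, mode="combine")
x = torch.rand(1, 64, 32, 32)                               # encoder feature map
scratch_mask = (torch.rand(1, 1, 256, 256) > 0.9).float()   # 1 = damaged pixel
out = block(x, scratch_mask)                                # 1 x 64 x 32 x 32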
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
use_sigmoid=False, num_D=3, getIntermFeat=False):
super(MultiscaleDiscriminator, self).__init__()
self.num_D = num_D
self.n_layers = n_layers
self.getIntermFeat = getIntermFeat
for i in range(num_D):
netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
if getIntermFeat:
for j in range(n_layers+2):
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
else:
setattr(self, 'layer'+str(i), netD.model)
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def singleD_forward(self, model, input):
if self.getIntermFeat:
result = [input]
for i in range(len(model)):
result.append(model[i](result[-1]))
return result[1:]
else:
return [model(input)]
def forward(self, input):
num_D = self.num_D
result = []
input_downsampled = input
for i in range(num_D):
if self.getIntermFeat:
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
else:
model = getattr(self, 'layer'+str(num_D-1-i))
result.append(self.singleD_forward(model, input_downsampled))
if i != (num_D-1):
input_downsampled = self.downsample(input_downsampled)
return result
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),opt.use_SN), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),opt.use_SN),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),opt.use_SN),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw),opt.use_SN)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[1:]
else:
return self.model(input)
class Patch_Attention_4(nn.Module): ## When combining the feature maps, use a conv layer and the mask
def __init__(self, in_channels, inter_channels, patch_size):
super(Patch_Attention_4, self).__init__()
self.patch_size=patch_size
# self.g = nn.Conv2d(
# in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
# )
# self.W = nn.Conv2d(
# in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
# )
# # for pytorch 0.3.1
# # nn.init.constant(self.W.weight, 0)
# # nn.init.constant(self.W.bias, 0)
# # for pytorch 0.4.0
# nn.init.constant_(self.W.weight, 0)
# nn.init.constant_(self.W.bias, 0)
# self.theta = nn.Conv2d(
# in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
# )
# self.phi = nn.Conv2d(
# in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
# )
self.F_Combine=nn.Conv2d(in_channels=1025,out_channels=512,kernel_size=3,stride=1,padding=1,bias=True)
norm_layer = get_norm_layer(norm_type="instance")
activation = nn.ReLU(True)
model = []
for i in range(1):
model += [
ResnetBlock(
inter_channels,
padding_type="reflect",
activation=activation,
norm_layer=norm_layer,
opt=None,
)
]
self.res_block = nn.Sequential(*model)
def Hard_Compose(self, input, dim, index):
# batch index select
# input: [B,C,HW]
# dim: scalar > 0
# index: [B, HW]
views = [input.size(0)] + [1 if i!=dim else -1 for i in range(1, len(input.size()))]
expanse = list(input.size())
expanse[0] = -1
expanse[dim] = -1
index = index.view(views).expand(expanse)
return torch.gather(input, dim, index)
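# A toy illustration of the batched gather above (dim=2): for every output slot j,
# the column idx[b, j] of the input is selected.
x = torch.arange(24).float().reshape(2, 3, 4)       # [B, C, N] with N = 4 patches
idx = torch.tensor([[3, 2, 1, 0], [0, 0, 0, 0]])    # [B, N] best-match index per slot
picked = torch.gather(x, 2, idx.view(2, 1, -1).expand(-1, 3, -1))
# picked[b, :, j] == x[b, :, idx[b, j]]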
def forward(self, z, mask): ## The shape of mask is Batch*1*H*W
x=self.res_block(z)
b,c,h,w=x.shape
## mask resize + dilation
# tmp = 1 - mask
mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
mask[mask > 0] = 1.0
# mask = 1 - mask
# tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
# mask *= tmp
# mask=1-mask
## 1: mask position 0: non-mask
mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()
all_patch_num=h*w/self.patch_size/self.patch_size
non_mask_region=non_mask_region.repeat(1,int(all_patch_num),1)
x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
y_unfold=x_unfold.permute(0,2,1)
x_unfold_normalized=F.normalize(x_unfold,dim=1)
y_unfold_normalized=F.normalize(y_unfold,dim=2)
correlation_matrix=torch.bmm(y_unfold_normalized,x_unfold_normalized)
correlation_matrix=correlation_matrix.masked_fill(non_mask_region==1.,-1e9)
correlation_matrix=F.softmax(correlation_matrix,dim=2)
# print(correlation_matrix)
R, max_arg=torch.max(correlation_matrix,dim=2)
composed_unfold=self.Hard_Compose(x_unfold, 2, max_arg)
composed_fold=F.fold(composed_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size)
concat_1=torch.cat((z,composed_fold,mask),dim=1)
concat_1=self.F_Combine(concat_1)
return concat_1
def inference_forward(self,z,mask): ## Reduce the extra memory cost
x=self.res_block(z)
b,c,h,w=x.shape
## mask resize + dilation
# tmp = 1 - mask
mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
mask[mask > 0] = 1.0
# mask = 1 - mask
# tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
# mask *= tmp
# mask=1-mask
## 1: mask position 0: non-mask
mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()[0,0,:] # 1*1*all_patch_num
all_patch_num=h*w/self.patch_size/self.patch_size
mask_index=torch.nonzero(non_mask_region,as_tuple=True)[0]
if len(mask_index)==0: ## No mask patch is selected, no attention is needed
composed_fold=x
else:
unmask_index=torch.nonzero(non_mask_region!=1,as_tuple=True)[0]
x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size)
Query_Patch=torch.index_select(x_unfold,2,mask_index)
Key_Patch=torch.index_select(x_unfold,2,unmask_index)
Query_Patch=Query_Patch.permute(0,2,1)
Query_Patch_normalized=F.normalize(Query_Patch,dim=2)
Key_Patch_normalized=F.normalize(Key_Patch,dim=1)
correlation_matrix=torch.bmm(Query_Patch_normalized,Key_Patch_normalized)
correlation_matrix=F.softmax(correlation_matrix,dim=2)
R, max_arg=torch.max(correlation_matrix,dim=2)
composed_unfold=self.Hard_Compose(Key_Patch, 2, max_arg)
x_unfold[:,:,mask_index]=composed_unfold
composed_fold=F.fold(x_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size)
concat_1=torch.cat((z,composed_fold,mask),dim=1)
concat_1=self.F_Combine(concat_1)
return concat_1
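# A minimal shape sketch; the 1025-channel F_Combine above fixes the feature width at
# 512 channels, so the hypothetical sizes below follow that constraint:
pa = Patch_Attention_4(512, 512, patch_size=8)
z = torch.rand(1, 512, 64, 64)                            # decoder-side feature map
scratch_mask = (torch.rand(1, 1, 256, 256) > 0.9).float()
out = pa(z, scratch_mask)                                 # 1 x 512 x 64 x 64
out_fast = pa.inference_forward(z, scratch_mask)          # memory-saving variant for inference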
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
self.loss = nn.MSELoss()
else:
self.loss = nn.BCELoss()
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
def __call__(self, input, target_is_real):
if isinstance(input[0], list):
loss = 0
for input_i in input:
pred = input_i[-1]
target_tensor = self.get_target_tensor(pred, target_is_real)
loss += self.loss(pred, target_tensor)
return loss
else:
target_tensor = self.get_target_tensor(input[-1], target_is_real)
return self.loss(input[-1], target_tensor)
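# A minimal sketch with hypothetical multiscale predictions; the last element of each
# inner list is the patch map the loss is evaluated on:
criterion = GANLoss(use_lsgan=True, tensor=torch.FloatTensor)
pred_fake = [[torch.rand(1, 1, 30, 30)], [torch.rand(1, 1, 15, 15)]]
loss_G = criterion(pred_fake, True)    # generator: push fakes towards the real label
loss_D = criterion(pred_fake, False)   # discriminator: push fakes towards the fake label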
####################################### VGG Loss
from torchvision import models
class VGG19_torch(torch.nn.Module):
def __init__(self, requires_grad=False):
super(VGG19_torch, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class VGGLoss_torch(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss_torch, self).__init__()
self.vgg = VGG19_torch().cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
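# A usage sketch; this downloads the pretrained VGG19 weights and needs a GPU, since
# the constructor moves the network to CUDA:
perceptual = VGGLoss_torch(gpu_ids=[0])
fake = torch.rand(1, 3, 256, 256).cuda()
real = torch.rand(1, 3, 256, 256).cuda()
loss = perceptual(fake, real)   # weighted L1 over relu1_1 ... relu5_1 activations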

View File

@@ -1,333 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class Pix2PixHDModel(BaseModel):
def name(self):
return 'Pix2PixHDModel'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_L1):
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True,use_smooth_L1)
def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,smooth_l1):
return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg, g_kl, d_real,d_fake,smooth_l1),flags) if f]
return loss_filter
def initialize(self, opt):
BaseModel.initialize(self, opt)
if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
self.use_features = opt.instance_feat or opt.label_feat ## Clearly it is false
self.gen_features = self.use_features and not self.opt.load_features ## it is also false
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## This is just the original number of input channels
##### define networks
# Generator network
netG_input_nc = input_nc
if not opt.no_instance:
netG_input_nc += 1
if self.use_features:
netG_input_nc += opt.feat_num
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc
if not opt.no_instance:
netD_input_nc += 1
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
if self.opt.verbose:
print('---------- Networks initialized -------------')
# load networks
if not self.isTrain or opt.continue_train or opt.load_pretrain:
pretrained_path = '' if not self.isTrain else opt.load_pretrain
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
print("---------- G Networks reloaded -------------")
if self.isTrain:
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
print("---------- D Networks reloaded -------------")
if self.gen_features:
self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
# set loss functions and optimizers
if self.isTrain:
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0!
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
self.fake_pool = ImagePool(opt.pool_size)
self.old_lr = opt.lr
# define loss functions
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1)
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionFeat = torch.nn.L1Loss()
# self.criterionImage = torch.nn.SmoothL1Loss()
if not opt.no_vgg_loss:
self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)
self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1')
# initialize optimizers
# optimizer G
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
print("---------- Optimizers initialized -------------")
if opt.continue_train:
self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
for param_groups in self.optimizer_D.param_groups:
self.old_lr=param_groups['lr']
print("---------- Optimizers reloaded -------------")
print("---------- Current LR is %.8f -------------"%(self.old_lr))
## We also want to reload the optimizer state.
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
if self.opt.label_nc == 0:
input_label = label_map.data.cuda()
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)
# real images for training
if real_image is not None:
real_image = Variable(real_image.data.cuda())
# instance map for feature encoding
if self.use_features:
# get precomputed feature maps
if self.opt.load_features:
feat_map = Variable(feat_map.data.cuda())
if self.opt.label_feat:
inst_map = label_map.cuda()
return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
if input_label is None:
input_concat = test_image.detach()
else:
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
if use_pool:
fake_query = self.fake_pool.query(input_concat)
return self.netD.forward(fake_query)
else:
return self.netD.forward(input_concat)
def forward(self, label, inst, image, feat, infer=False):
# Encode Inputs
input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
# Fake Generation
if self.use_features:
if not self.opt.load_features:
feat_map = self.netE.forward(real_image, inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
hiddens = self.netG.forward(input_concat, 'enc')
noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
# This is a reduced VAE implementation where we assume the outputs follow a multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
# We follow the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py)
fake_image = self.netG.forward(hiddens + noise, 'dec')
if self.opt.no_cgan:
# Fake Detection and Loss
pred_fake_pool = self.discriminate(None, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(None, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(fake_image)
loss_G_GAN = self.criterionGAN(pred_fake, True)
else:
# Fake Detection and Loss
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(input_label, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl
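# For a posterior N(hiddens, I) against the N(0, I) prior, the analytic KL reduces to
# 0.5 * sum(hiddens ** 2); the mean used here differs only by constant factors, which
# the opt.kl weight absorbs.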
# GAN feature matching loss
loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i])-1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
if not self.opt.no_vgg_loss:
loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
smooth_l1_loss=0
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,smooth_l1_loss ), None if not infer else fake_image ]
def inference(self, label, inst, image=None, feat=None):
# Encode Inputs
image = Variable(image) if image is not None else None
input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
# Fake Generation
if self.use_features:
if self.opt.use_encoded_image:
# encode the real image to get feature map
feat_map = self.netE.forward(real_image, inst_map)
else:
# sample clusters from precomputed features
feat_map = self.sample_features(inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
fake_image = self.netG.forward(input_concat)
return fake_image
def sample_features(self, inst):
# read precomputed feature clusters
cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
features_clustered = np.load(cluster_path, encoding='latin1').item()
# randomly sample from the feature clusters
inst_np = inst.cpu().numpy().astype(int)
feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
for i in np.unique(inst_np):
label = i if i < 1000 else i//1000
if label in features_clustered:
feat = features_clustered[label]
cluster_idx = np.random.randint(0, feat.shape[0])
idx = (inst == int(i)).nonzero()
for k in range(self.opt.feat_num):
feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
if self.opt.data_type==16:
feat_map = feat_map.half()
return feat_map
def encode_features(self, image, inst):
image = Variable(image.cuda(), volatile=True)
feat_num = self.opt.feat_num
h, w = inst.size()[2], inst.size()[3]
block_num = 32
feat_map = self.netE.forward(image, inst.cuda())
inst_np = inst.cpu().numpy().astype(int)
feature = {}
for i in range(self.opt.label_nc):
feature[i] = np.zeros((0, feat_num+1))
for i in np.unique(inst_np):
label = i if i < 1000 else i//1000
idx = (inst == int(i)).nonzero()
num = idx.size()[0]
idx = idx[num//2,:]
val = np.zeros((1, feat_num+1))
for k in range(feat_num):
val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
val[0, feat_num] = float(num) / (h * w // block_num)
feature[label] = np.append(feature[label], val, axis=0)
return feature
def get_edges(self, t):
edge = torch.cuda.ByteTensor(t.size()).zero_()
edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
if self.opt.data_type==16:
return edge.half()
else:
return edge.float()
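# A CPU sketch of the neighbour-difference trick above: pixels on either side of an
# instance boundary are flagged as edges.
t = torch.tensor([[[[1, 1, 2, 2], [1, 1, 2, 2]]]])
edge = torch.zeros_like(t, dtype=torch.bool)
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
# edge is now True at columns 1 and 2 of every row, i.e. on both sides of the 1|2 boundary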
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
self.save_optimizer(self.optimizer_G,"G",which_epoch)
self.save_optimizer(self.optimizer_D,"D",which_epoch)
if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
def update_fixed_params(self):
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
if self.opt.verbose:
print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
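# A numeric sketch of the linear decay above: with lr = 2e-4 and niter_decay = 100,
# every call subtracts 2e-6 from the learning rate.
lr0, niter_decay = 2e-4, 100
lrd = lr0 / niter_decay
schedule = [lr0 - (k + 1) * lrd for k in range(niter_decay)]   # last entry is ~0.0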
class InferenceModel(Pix2PixHDModel):
def forward(self, inp):
label, inst = inp
return self.inference(label, inst)

View File

@@ -1,372 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class Pix2PixHDModel(BaseModel):
def name(self):
return 'Pix2PixHDModel'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True, True)
def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake):
return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake), flags) if f]
return loss_filter
def initialize(self, opt):
BaseModel.initialize(self, opt)
if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
self.use_features = opt.instance_feat or opt.label_feat ## Clearly it is false
self.gen_features = self.use_features and not self.opt.load_features ## it is also false
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## This is just the original number of input channels
##### define networks
# Generator network
netG_input_nc = input_nc
if not opt.no_instance:
netG_input_nc += 1
if self.use_features:
netG_input_nc += opt.feat_num
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc
if not opt.no_instance:
netD_input_nc += 1
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt,opt.norm, use_sigmoid,
opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
self.feat_D=networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
if self.opt.verbose:
print('---------- Networks initialized -------------')
# load networks
if not self.isTrain or opt.continue_train or opt.load_pretrain:
pretrained_path = '' if not self.isTrain else opt.load_pretrain
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
print("---------- G Networks reloaded -------------")
if self.isTrain:
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
self.load_network(self.feat_D, 'feat_D', opt.which_epoch, pretrained_path)
print("---------- D Networks reloaded -------------")
# set loss functions and optimizers
if self.isTrain:
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0!
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
self.fake_pool = ImagePool(opt.pool_size)
self.old_lr = opt.lr
# define loss functions
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionFeat = torch.nn.L1Loss()
if not opt.no_vgg_loss:
self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)
# Names so we can breakout loss
self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'G_featD', 'featD_real','featD_fake')
# initialize optimizers
# optimizer G
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
params = list(self.feat_D.parameters())
self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
print("---------- Optimizers initialized -------------")
if opt.continue_train:
self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
self.load_optimizer(self.optimizer_featD,'featD',opt.which_epoch)
for param_groups in self.optimizer_D.param_groups:
self.old_lr = param_groups['lr']
print("---------- Optimizers reloaded -------------")
print("---------- Current LR is %.8f -------------" % (self.old_lr))
## We also want to re-load the parameters of optimizer.
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
if self.opt.label_nc == 0:
input_label = label_map.data.cuda()
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)
# real images for training
if real_image is not None:
real_image = Variable(real_image.data.cuda())
# instance map for feature encoding
if self.use_features:
# get precomputed feature maps
if self.opt.load_features:
feat_map = Variable(feat_map.data.cuda())
if self.opt.label_feat:
inst_map = label_map.cuda()
return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
if input_label is None:
input_concat = test_image.detach()
else:
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
if use_pool:
fake_query = self.fake_pool.query(input_concat)
return self.netD.forward(fake_query)
else:
return self.netD.forward(input_concat)
def feat_discriminate(self,input):
return self.feat_D.forward(input.detach())
def forward(self, label, inst, image, feat, infer=False):
# Encode Inputs
input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
# Fake Generation
if self.use_features:
if not self.opt.load_features:
feat_map = self.netE.forward(real_image, inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
hiddens = self.netG.forward(input_concat, 'enc')
noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
# This is a reduced VAE implementation where we assume the outputs follow a multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
# We follow the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py)
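# Note: for a unit-variance Gaussian posterior N(hiddens, I) against a standard normal prior N(0, I),
# the KL divergence reduces to 0.5 * ||hiddens||^2 per sample, which is why loss_G_kl below is simply
# the mean of the squared hidden activations scaled by opt.kl.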
fake_image = self.netG.forward(hiddens + noise, 'dec')
####################
##### GAN for the intermediate feature
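# The auxiliary discriminator feat_D works on the encoder output `hiddens`: features of samples
# flagged with inst == 1 (collected as real_old_feat) are treated as "fake" and the remaining
# synthetic features as "real", so the generator is encouraged to map real old photos into the
# same latent distribution as the synthetic data.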
real_old_feat =[]
syn_feat = []
for index,x in enumerate(inst):
if x==1:
real_old_feat.append(hiddens[index].unsqueeze(0))
else:
syn_feat.append(hiddens[index].unsqueeze(0))
L=min(len(real_old_feat),len(syn_feat))
real_old_feat=real_old_feat[:L]
syn_feat=syn_feat[:L]
real_old_feat=torch.cat(real_old_feat,0)
syn_feat=torch.cat(syn_feat,0)
pred_fake_feat=self.feat_discriminate(real_old_feat)
loss_featD_fake = self.criterionGAN(pred_fake_feat, False)
pred_real_feat=self.feat_discriminate(syn_feat)
loss_featD_real = self.criterionGAN(pred_real_feat, True)
pred_fake_feat_G=self.feat_D.forward(real_old_feat)
loss_G_featD=self.criterionGAN(pred_fake_feat_G,True)
#####################################
if self.opt.no_cgan:
# Fake Detection and Loss
pred_fake_pool = self.discriminate(None, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(None, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(fake_image)
loss_G_GAN = self.criterionGAN(pred_fake, True)
else:
# Fake Detection and Loss
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(input_label, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl
# GAN feature matching loss
loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i]) - 1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j],
pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
if not self.opt.no_vgg_loss:
loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
# Only return the fake_B image if necessary to save BW
return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,loss_G_featD, loss_featD_real, loss_featD_fake),
None if not infer else fake_image]
def inference(self, label, inst, image=None, feat=None):
# Encode Inputs
image = Variable(image) if image is not None else None
input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
# Fake Generation
if self.use_features:
if self.opt.use_encoded_image:
# encode the real image to get feature map
feat_map = self.netE.forward(real_image, inst_map)
else:
# sample clusters from precomputed features
feat_map = self.sample_features(inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
fake_image = self.netG.forward(input_concat)
return fake_image
def sample_features(self, inst):
# read precomputed feature clusters
cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
features_clustered = np.load(cluster_path, encoding='latin1').item()
# randomly sample from the feature clusters
inst_np = inst.cpu().numpy().astype(int)
feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
for i in np.unique(inst_np):
label = i if i < 1000 else i // 1000
if label in features_clustered:
feat = features_clustered[label]
cluster_idx = np.random.randint(0, feat.shape[0])
idx = (inst == int(i)).nonzero()
for k in range(self.opt.feat_num):
feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
if self.opt.data_type == 16:
feat_map = feat_map.half()
return feat_map
def encode_features(self, image, inst):
image = Variable(image.cuda(), volatile=True)
feat_num = self.opt.feat_num
h, w = inst.size()[2], inst.size()[3]
block_num = 32
feat_map = self.netE.forward(image, inst.cuda())
inst_np = inst.cpu().numpy().astype(int)
feature = {}
for i in range(self.opt.label_nc):
feature[i] = np.zeros((0, feat_num + 1))
for i in np.unique(inst_np):
label = i if i < 1000 else i // 1000
idx = (inst == int(i)).nonzero()
num = idx.size()[0]
idx = idx[num // 2, :]
val = np.zeros((1, feat_num + 1))
for k in range(feat_num):
val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
val[0, feat_num] = float(num) / (h * w // block_num)
feature[label] = np.append(feature[label], val, axis=0)
return feature
def get_edges(self, t):
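# Build a boundary map from the instance map: a pixel is marked as an edge when it differs from
# its horizontal or vertical neighbour; this edge channel is concatenated to the label map in encode_input().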
edge = torch.cuda.ByteTensor(t.size()).zero_()
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
if self.opt.data_type == 16:
return edge.half()
else:
return edge.float()
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
self.save_network(self.feat_D,'featD',which_epoch,self.gpu_ids)
self.save_optimizer(self.optimizer_G, "G", which_epoch)
self.save_optimizer(self.optimizer_D, "D", which_epoch)
self.save_optimizer(self.optimizer_featD,'featD',which_epoch)
if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
def update_fixed_params(self):
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
if self.opt.verbose:
print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_featD.param_groups:
param_group['lr'] = lr
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
class InferenceModel(Pix2PixHDModel):
def forward(self, inp):
label, inst = inp
return self.inference(label, inst)


@@ -1,469 +0,0 @@
# Copyright (c) Microsoft Corporation
import argparse
import torch
class BaseOptions:
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
# experiment specifics
self.parser.add_argument(
"--name",
type=str,
default="label2city",
help="name of the experiment. It decides where to store samples and models",
)
self.parser.add_argument(
"--gpu_ids",
type=str,
default="0",
help="gpu ids: e.g. 0, or 0,1,2, or 0,2; use -1 for CPU",
)
self.parser.add_argument(
"--checkpoints_dir",
type=str,
default="./checkpoints",
help="models are saved here",
) ## note: to add this param when using philly
# self.parser.add_argument('--project_dir', type=str, default='./', help='the project is saved here') ################### This is necessary for philly
self.parser.add_argument(
"--outputs_dir", type=str, default="./outputs", help="models are saved here"
) ## note: add this param when using philly; please end with '/'
self.parser.add_argument(
"--model", type=str, default="pix2pixHD", help="which model to use"
)
self.parser.add_argument(
"--norm",
type=str,
default="instance",
help="instance normalization or batch normalization",
)
self.parser.add_argument(
"--use_dropout", action="store_true", help="use dropout for the generator"
)
self.parser.add_argument(
"--data_type",
default=32,
type=int,
choices=[8, 16, 32],
help="Supported data type i.e. 8, 16, 32 bit",
)
self.parser.add_argument(
"--verbose", action="store_true", default=False, help="toggles verbose"
)
# input/output sizes
self.parser.add_argument(
"--batchSize", type=int, default=1, help="input batch size"
)
self.parser.add_argument(
"--loadSize", type=int, default=1024, help="scale images to this size"
)
self.parser.add_argument(
"--fineSize", type=int, default=512, help="then crop to this size"
)
self.parser.add_argument(
"--label_nc", type=int, default=35, help="# of input label channels"
)
self.parser.add_argument(
"--input_nc", type=int, default=3, help="# of input image channels"
)
self.parser.add_argument(
"--output_nc", type=int, default=3, help="# of output image channels"
)
# for setting inputs
self.parser.add_argument(
"--dataroot", type=str, default="./datasets/cityscapes/"
)
self.parser.add_argument(
"--resize_or_crop",
type=str,
default="scale_width",
help="scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]",
)
self.parser.add_argument(
"--serial_batches",
action="store_true",
help="if true, takes images in order to make batches, otherwise takes them randomly",
)
self.parser.add_argument(
"--no_flip",
action="store_true",
help="if specified, do not flip the images for data augmentation",
)
self.parser.add_argument(
"--nThreads", default=2, type=int, help="# threads for loading data"
)
self.parser.add_argument(
"--max_dataset_size",
type=int,
default=float("inf"),
help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
)
# for displays
self.parser.add_argument(
"--display_winsize", type=int, default=512, help="display window size"
)
self.parser.add_argument(
"--tf_log",
action="store_true",
help="if specified, use tensorboard logging. Requires tensorflow installed",
)
# for generator
self.parser.add_argument(
"--netG", type=str, default="global", help="selects model to use for netG"
)
self.parser.add_argument(
"--ngf", type=int, default=64, help="# of gen filters in first conv layer"
)
self.parser.add_argument(
"--k_size", type=int, default=3, help="# kernel size conv layer"
)
self.parser.add_argument("--use_v2", action="store_true", help="use DCDCv2")
self.parser.add_argument("--mc", type=int, default=1024, help="# max channel")
self.parser.add_argument(
"--start_r", type=int, default=3, help="start layer to use resblock"
)
self.parser.add_argument(
"--n_downsample_global",
type=int,
default=4,
help="number of downsampling layers in netG",
)
self.parser.add_argument(
"--n_blocks_global",
type=int,
default=9,
help="number of residual blocks in the global generator network",
)
self.parser.add_argument(
"--n_blocks_local",
type=int,
default=3,
help="number of residual blocks in the local enhancer network",
)
self.parser.add_argument(
"--n_local_enhancers",
type=int,
default=1,
help="number of local enhancers to use",
)
self.parser.add_argument(
"--niter_fix_global",
type=int,
default=0,
help="number of epochs during which only the outermost local enhancer is trained",
)
self.parser.add_argument(
"--load_pretrain",
type=str,
default="",
help="load the pretrained model from the specified location",
)
# for instance-wise features
self.parser.add_argument(
"--no_instance",
action="store_true",
help="if specified, do *not* add instance map as input",
)
self.parser.add_argument(
"--instance_feat",
action="store_true",
help="if specified, add encoded instance features as input",
)
self.parser.add_argument(
"--label_feat",
action="store_true",
help="if specified, add encoded label features as input",
)
self.parser.add_argument(
"--feat_num", type=int, default=3, help="vector length for encoded features"
)
self.parser.add_argument(
"--load_features",
action="store_true",
help="if specified, load precomputed feature maps",
)
self.parser.add_argument(
"--n_downsample_E",
type=int,
default=4,
help="# of downsampling layers in encoder",
)
self.parser.add_argument(
"--nef",
type=int,
default=16,
help="# of encoder filters in the first conv layer",
)
self.parser.add_argument(
"--n_clusters", type=int, default=10, help="number of clusters for features"
)
# diy
self.parser.add_argument(
"--self_gen", action="store_true", help="self generate"
)
self.parser.add_argument(
"--mapping_n_block",
type=int,
default=3,
help="number of resblock in mapping",
)
self.parser.add_argument(
"--map_mc", type=int, default=64, help="max channel of mapping"
)
self.parser.add_argument("--kl", type=float, default=0, help="KL Loss")
self.parser.add_argument(
"--load_pretrainA",
type=str,
default="",
help="load the pretrained model from the specified location",
)
self.parser.add_argument(
"--load_pretrainB",
type=str,
default="",
help="load the pretrained model from the specified location",
)
self.parser.add_argument("--feat_gan", action="store_true")
self.parser.add_argument("--no_cgan", action="store_true")
self.parser.add_argument("--map_unet", action="store_true")
self.parser.add_argument("--map_densenet", action="store_true")
self.parser.add_argument("--fcn", action="store_true")
self.parser.add_argument(
"--is_image", action="store_true", help="train image recon only pair data"
)
self.parser.add_argument("--label_unpair", action="store_true")
self.parser.add_argument("--mapping_unpair", action="store_true")
self.parser.add_argument("--unpair_w", type=float, default=1.0)
self.parser.add_argument("--pair_num", type=int, default=-1)
self.parser.add_argument("--Gan_w", type=float, default=1)
self.parser.add_argument("--feat_dim", type=int, default=-1)
self.parser.add_argument("--abalation_vae_len", type=int, default=-1)
######################### useless, just to cooperate with docker
self.parser.add_argument("--gpu", type=str)
self.parser.add_argument("--dataDir", type=str)
self.parser.add_argument("--modelDir", type=str)
self.parser.add_argument("--logDir", type=str)
self.parser.add_argument("--data_dir", type=str)
self.parser.add_argument("--use_skip_model", action="store_true")
self.parser.add_argument("--use_segmentation_model", action="store_true")
self.parser.add_argument("--spatio_size", type=int, default=64)
self.parser.add_argument("--test_random_crop", action="store_true")
##########################
self.parser.add_argument("--contain_scratch_L", action="store_true")
self.parser.add_argument(
"--mask_dilation", type=int, default=0
) ## Don't change the input, only dilation the mask
self.parser.add_argument(
"--irregular_mask",
type=str,
default="",
help="This is the root of the mask",
)
self.parser.add_argument(
"--mapping_net_dilation",
type=int,
default=1,
help="This parameter is the dilation size of the translation net",
)
self.parser.add_argument(
"--VOC",
type=str,
default="VOC_RGB_JPEGImages.bigfile",
help="The root of VOC dataset",
)
self.parser.add_argument(
"--non_local", type=str, default="", help="which non_local setting"
)
self.parser.add_argument(
"--NL_fusion_method",
type=str,
default="add",
help="how to fuse the origin feature and nl feature",
)
self.parser.add_argument(
"--NL_use_mask",
action="store_true",
help="If use mask while using Non-local mapping model",
)
self.parser.add_argument(
"--correlation_renormalize",
action="store_true",
help="After masking out entries of the (softmaxed) correlation matrix, its rows no longer sum to 1; enable this param to re-normalize the weights",
)
self.parser.add_argument(
"--Smooth_L1", action="store_true", help="Use L1 Loss in image level"
)
self.parser.add_argument(
"--face_restore_setting",
type=int,
default=1,
help="This is for the aligned face restoration",
)
self.parser.add_argument("--face_clean_url", type=str, default="")
self.parser.add_argument("--syn_input_url", type=str, default="")
self.parser.add_argument("--syn_gt_url", type=str, default="")
self.parser.add_argument(
"--test_on_synthetic",
action="store_true",
help="If you want to test on the synthetic data, enable this parameter",
)
self.parser.add_argument(
"--use_SN", action="store_true", help="Add SN to every parametric layer"
)
self.parser.add_argument(
"--use_two_stage_mapping",
action="store_true",
help="choose the model which uses two stage",
)
self.parser.add_argument("--L1_weight", type=float, default=10.0)
self.parser.add_argument("--softmax_temperature", type=float, default=1.0)
self.parser.add_argument(
"--patch_similarity",
action="store_true",
help="Enable this to use 3x3 patches when calculating similarity",
)
self.parser.add_argument(
"--use_self",
action="store_true",
help="Enable this to also use the original feature (diagonal == 1) when constructing the new feature maps",
)
self.parser.add_argument("--use_own_dataset", action="store_true")
self.parser.add_argument(
"--test_hole_two_folders",
action="store_true",
help="Enable this to test restoration with inpainting, given two folders containing the masks and the old photos respectively",
)
self.parser.add_argument(
"--no_hole",
action="store_true",
help="While testing the full model on non-scratch data, do not add random masks to the real old photos",
) ## Only for testing
self.parser.add_argument(
"--random_hole",
action="store_true",
help="While training the full model, add a hole with 50% probability",
)
self.parser.add_argument(
"--NL_res", action="store_true", help="NL+Resdual Block"
)
self.parser.add_argument(
"--image_L1", action="store_true", help="Image level loss: L1"
)
self.parser.add_argument(
"--hole_image_no_mask",
action="store_true",
help="while testing, provide the hole image but not the mask",
)
self.parser.add_argument(
"--down_sample_degradation",
action="store_true",
help="down_sample the image only, corresponds to [down_sample_face]",
)
self.parser.add_argument(
"--norm_G",
type=str,
default="spectralinstance",
help="The norm type of Generator",
)
self.parser.add_argument(
"--init_G",
type=str,
default="xavier",
help="normal|xavier|xavier_uniform|kaiming|orthogonal|none",
)
self.parser.add_argument("--use_new_G", action="store_true")
self.parser.add_argument("--use_new_D", action="store_true")
self.parser.add_argument(
"--only_voc",
action="store_true",
help="test the trained celebA face model using VOC faces",
)
self.parser.add_argument(
"--cosin_similarity",
action="store_true",
help="For non-local, use cosine to calculate the similarity",
)
self.parser.add_argument(
"--downsample_mode",
type=str,
default="nearest",
help="For partial non-local, choose how to downsample the mask",
)
self.parser.add_argument(
"--mapping_exp",
type=int,
default=0,
help="Default 0: original PNL|1: Multi-Scale Patch Attention",
)
self.parser.add_argument(
"--inference_optimize", action="store_true", help="optimize the memory cost"
)
self.initialized = True
def parse(self, custom_args: list):
assert len(custom_args) > 0, "Manually Pass Arguments!"
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args(custom_args)
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(",")
self.opt.gpu_ids = []
for str_id in str_ids:
int_id = int(str_id)
if int_id >= 0:
self.opt.gpu_ids.append(int_id)
# set gpu ids
if torch.cuda.is_available() and len(self.opt.gpu_ids) > 0:
if torch.cuda.device_count() > self.opt.gpu_ids[0]:
try:
torch.cuda.set_device(self.opt.gpu_ids[0])
except:
print("Failed to set GPU device. Using CPU...")
else:
print("Invalid GPU ID. Using CPU...")
return self.opt


@@ -1,100 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.parser.add_argument("--ntest", type=int, default=float("inf"), help="# of test examples.")
self.parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
self.parser.add_argument(
"--aspect_ratio", type=float, default=1.0, help="aspect ratio of result images"
)
self.parser.add_argument("--phase", type=str, default="test", help="train, val, test, etc")
self.parser.add_argument(
"--which_epoch",
type=str,
default="latest",
help="which epoch to load? set to latest to use latest cached model",
)
self.parser.add_argument("--how_many", type=int, default=50, help="how many test images to run")
self.parser.add_argument(
"--cluster_path",
type=str,
default="features_clustered_010.npy",
help="the path for clustered results of encoded features",
)
self.parser.add_argument(
"--use_encoded_image",
action="store_true",
help="if specified, encode the real image to get the feature map",
)
self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
self.parser.add_argument(
"--start_epoch",
type=int,
default=-1,
help="write the start_epoch of iter.txt into this parameter",
)
self.parser.add_argument("--test_dataset", type=str, default="Real_RGB_old.bigfile")
self.parser.add_argument(
"--no_degradation",
action="store_true",
help="when training the mapping, enable this parameter so that no degradation is added to the clean image",
)
self.parser.add_argument(
"--no_load_VAE",
action="store_true",
help="when training the mapping, enable this parameter to randomly initialize the encoder and decoder",
)
self.parser.add_argument(
"--use_v2_degradation",
action="store_true",
help="enable this parameter so that 4 kinds of degradation are used to synthesize corruption",
)
self.parser.add_argument("--use_vae_which_epoch", type=str, default="latest")
self.isTrain = False
self.parser.add_argument("--generate_pair", action="store_true")
self.parser.add_argument("--multi_scale_test", type=float, default=0.5)
self.parser.add_argument("--multi_scale_threshold", type=float, default=0.5)
self.parser.add_argument(
"--mask_need_scale",
action="store_true",
help="enable this param to indicate that the pixel range of the mask is 0-255",
)
self.parser.add_argument("--scale_num", type=int, default=1)
self.parser.add_argument(
"--save_feature_url", type=str, default="", help="where to save the extracted features"
)
self.parser.add_argument(
"--test_input", type=str, default="", help="A directory or a root of bigfile"
)
self.parser.add_argument("--test_mask", type=str, default="", help="A directory or a root of bigfile")
self.parser.add_argument("--test_gt", type=str, default="", help="A directory or a root of bigfile")
self.parser.add_argument(
"--scale_input", action="store_true", help="While testing, choose to scale the input firstly"
)
self.parser.add_argument(
"--save_feature_name", type=str, default="features.json", help="The name of saved features"
)
self.parser.add_argument(
"--test_rgb_old_wo_scratch", action="store_true", help="Same setting as the original test"
)
self.parser.add_argument("--test_mode", type=str, default="Crop", help="Scale|Full|Crop")
self.parser.add_argument("--Quality_restore", action="store_true", help="For RGB images")
self.parser.add_argument(
"--Scratch_and_Quality_restore", action="store_true", help="For scratched images"
)
self.parser.add_argument("--HR", action="store_true", help="Large input size with scratches")


@@ -1,142 +0,0 @@
# Copyright (c) Microsoft Corporation
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torch
import cv2
import os
from .models.mapping_model import Pix2PixHDModel_Mapping
from .options.test_options import TestOptions
tensor2image = transforms.ToPILImage()
def data_transforms(img, method=Image.BILINEAR, scale=False):
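# Resize so that both dimensions become multiples of 4; when scale is True, the shorter side is
# first brought to 256 while keeping the aspect ratio.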
ow, oh = img.size
pw, ph = ow, oh
if scale:
if ow < oh:
ow = 256
oh = ph / pw * 256
else:
oh = 256
ow = pw / ph * 256
h = int(round(oh / 4) * 4)
w = int(round(ow / 4) * 4)
if (h == ph) and (w == pw):
return img
return img.resize((w, h), method)
def data_transforms_rgb_old(img):
w, h = img.size
A = img
if w < 256 or h < 256:
A = transforms.Resize(256, Image.BILINEAR)(img)
return transforms.CenterCrop(256)(A)
def irregular_hole_synthesize(img, mask):
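# Composite the mask onto the image: masked pixels are replaced with white (255), producing the
# "hole" image that is fed to the restoration model when NL_use_mask is enabled.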
img_np = np.array(img).astype("uint8")
mask_np = np.array(mask).astype("uint8")
mask_np = mask_np / 255
img_new = img_np * (1 - mask_np) + mask_np * 255
hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB")
return hole_img
def parameter_set(opt, ckpt_dir):
## Default parameters
opt.serial_batches = True # no shuffle
opt.no_flip = True # no flip
opt.label_nc = 0
opt.n_downsample_global = 3
opt.mc = 64
opt.k_size = 4
opt.start_r = 1
opt.mapping_n_block = 6
opt.map_mc = 512
opt.no_instance = True
opt.checkpoints_dir = ckpt_dir
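# Each restoration mode below selects a different pretrained mapping model: --Quality_restore loads
# the quality-only mapping, --Scratch_and_Quality_restore additionally enables the masked non-local
# (scratch) mapping, and --HR switches to the patch-attention variant with mask dilation for large inputs.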
if opt.Quality_restore:
opt.name = "mapping_quality"
opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality")
if opt.Scratch_and_Quality_restore:
opt.NL_res = True
opt.use_SN = True
opt.correlation_renormalize = True
opt.NL_use_mask = True
opt.NL_fusion_method = "combine"
opt.non_local = "Setting_42"
opt.name = "mapping_scratch"
opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch")
if opt.HR:
opt.mapping_exp = 1
opt.inference_optimize = True
opt.mask_dilation = 3
opt.name = "mapping_Patch_Attention"
def global_test(
ckpt_dir: str, custom_args: list, input_image: Image, mask: Image = None
) -> Image:
opt = TestOptions().parse(custom_args)
parameter_set(opt, ckpt_dir)
model = Pix2PixHDModel_Mapping()
model.initialize(opt)
model.eval()
img_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
mask_transform = transforms.ToTensor()
print("processing...")
if opt.NL_use_mask:
if opt.mask_dilation != 0:
kernel = np.ones((3, 3), np.uint8)
mask = np.array(mask)
mask = cv2.dilate(mask, kernel, iterations=opt.mask_dilation)
mask = Image.fromarray(mask.astype("uint8"))
input_image = irregular_hole_synthesize(input_image, mask)
mask = mask_transform(mask)
mask = mask[:1, :, :] ## Convert to single channel
mask = mask.unsqueeze(0)
input_image = img_transform(input_image)
input_image = input_image.unsqueeze(0)
else:
if opt.test_mode == "Scale":
input_image = data_transforms(input_image, scale=True)
elif opt.test_mode == "Full":
input_image = data_transforms(input_image, scale=False)
elif opt.test_mode == "Crop":
input_image = data_transforms_rgb_old(input_image)
input_image = img_transform(input_image)
input_image = input_image.unsqueeze(0)
mask = torch.zeros_like(input_image)
with torch.no_grad():
generated = model.inference(input_image, mask)
restored = torch.clamp((generated.data.cpu() + 1.0) / 2.0, 0.0, 1.0) * 255
return tensor2image(restored[0].byte())
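For reference, a minimal sketch of how `global_test` could be called directly (not part of the original file): it assumes the checkpoint layout described in the README, a quality-only restoration, and illustrative file paths.

from PIL import Image
from lib_bopb2l.Global.test import global_test  # import path after the submodule move

photo = Image.open("old_photo.png").convert("RGB")  # illustrative input path
restored = global_test(
    ckpt_dir="lib_bopb2l/Global/checkpoints/restoration",  # illustrative checkpoint folder
    custom_args=["--Quality_restore"],  # flags defined in BaseOptions/TestOptions
    input_image=photo,
)
restored.save("restored.png")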


@@ -1,36 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import random
import torch
from torch.autograd import Variable
class ImagePool:
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
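# Keep a buffer of previously generated images: each incoming fake is either stored and returned
# as-is, or (once the pool is full) swapped with a randomly chosen older fake with 50% probability,
# so the discriminator sees a history of generations rather than only the latest batch.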
if self.pool_size == 0:
return images
return_images = []
for image in images.data:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5:
random_id = random.randint(0, self.pool_size - 1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else:
return_images.append(image)
return_images = Variable(torch.cat(return_images, 0))
return return_images


@@ -1,58 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import torch.nn as nn
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
if isinstance(image_tensor, list):
image_numpy = []
for i in range(len(image_tensor)):
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
return image_numpy
image_numpy = image_tensor.cpu().float().numpy()
if normalize:
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
else:
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
image_numpy = np.clip(image_numpy, 0, 255)
if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
image_numpy = image_numpy[:, :, 0]
return image_numpy.astype(imtype)
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8):
if n_label == 0:
return tensor2im(label_tensor, imtype)
label_tensor = label_tensor.cpu().float()
if label_tensor.size()[0] > 1:
label_tensor = label_tensor.max(0, keepdim=True)[1]
label_tensor = Colorize(n_label)(label_tensor)
label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
return label_numpy.astype(imtype)
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)


@@ -1,21 +0,0 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -14,11 +14,11 @@ This is an Extension for the [Automatic1111 Webui](https://github.com/AUTOMATIC1
## Requirements
0. Install this Extension
1. Download `global_checkpoints.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases)
2. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/Global`
2. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Global`
3. Download `face_checkpoints.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases)
4. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/Face_Enhancement`
4. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Face_Enhancement`
5. Download `shape_predictor_68_face_landmarks.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases)
6. Extract the `.dat` **file** into `~webui/extensions/sd-webui-old-photo-restoration/Face_Detection`
6. Extract the `.dat` **file** into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Face_Detection`
> The [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases) page includes the original links, as well as the backups mirrored by myself
@@ -40,3 +40,4 @@ After installing this Extension, there will be an **Old Photo Restoration** sect
- *(This is **different** from the Webui built-in ones)*
- **High Resolution:** Use higher parameters to do the processing
- *(Only has an effect when either `Process Scratch` or `Face Restore` is also enabled)*
- **Use CPU:** Enable this if you do not have an Nvidia GPU or are getting an **Out of Memory** error

lib_bopb2l Submodule

Submodule lib_bopb2l added at b5b24d5475


@@ -2,14 +2,28 @@ import os
EXTENSION_FOLDER = os.path.dirname(os.path.realpath(__file__))
FACE = os.path.join(EXTENSION_FOLDER, "Face_Enhancement", "checkpoints")
FACE = os.path.join(
EXTENSION_FOLDER,
"lib_bopb2l",
"Face_Enhancement",
"checkpoints",
)
GLOBAL = os.path.join(
EXTENSION_FOLDER,
"lib_bopb2l",
"Global",
"checkpoints",
)
MDL = os.path.join(
EXTENSION_FOLDER,
"lib_bopb2l",
"Face_Detection",
"shape_predictor_68_face_landmarks.dat",
)
if not os.path.exists(FACE):
print("\n[Warning] face_checkpoints not detected! Please download it from Release!")
GLOBAL = os.path.join(EXTENSION_FOLDER, "Global", "checkpoints")
if not os.path.exists(GLOBAL):
print("[Warning] global_checkpoints not detected! Please download it from Release!")
MDL = os.path.join(EXTENSION_FOLDER, "Face_Detection", "shape_predictor_68_face_landmarks.dat")
if not os.path.exists(MDL):
print("[Warning] face_landmarks not detected! Please download it from Release!\n")


@@ -1,10 +1,10 @@
scikit-image
easydict
PyYAML
dominate
dill
tensorboardX
scipy
opencv-python
dominate
easydict
einops
matplotlib
opencv-python
scikit-image
scipy
tensorboardX


@@ -5,12 +5,12 @@ from modules.api import api
from fastapi import FastAPI, Body
import gradio as gr
from scripts.main_function import main
from scripts.bopb2l_main import main
def bop_api(_: gr.Blocks, app: FastAPI):
@app.post("/bop")
@app.post("/bopb2l/restore")
async def old_photo_restoration(
image: str = Body("", title="input image"),
scratch: bool = Body(False, title="process scratch"),
@@ -22,9 +22,7 @@ def bop_api(_: gr.Blocks, app: FastAPI):
return
input_image = api.decode_base64_to_image(image)
img = main(input_image, scratch, hr, face_res, cpu)
return {"image": api.encode_pil_to_base64(img).decode("utf-8")}
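As a rough illustration, the sketch below posts a base64-encoded image to the `/bopb2l/restore` endpoint; the field names other than `image` and `scratch` are assumed from `main()`'s signature, and the address is the usual webui default.

import base64
import requests

with open("old_photo.png", "rb") as f:
    payload = {
        "image": base64.b64encode(f.read()).decode("utf-8"),
        "scratch": True,
        "hr": False,       # assumed field name
        "face_res": True,  # assumed field name
        "cpu": True,       # assumed field name
    }

resp = requests.post("http://127.0.0.1:7860/bopb2l/restore", json=payload)
with open("restored.png", "wb") as out:
    out.write(base64.b64decode(resp.json()["image"]))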


@@ -1,7 +1,8 @@
from modules import scripts_postprocessing, ui_components
from modules.ui_components import InputAccordion
from modules import scripts_postprocessing
import gradio as gr
from scripts.main_function import main
from scripts.bopb2l_main import main
class OldPhotoRestoration(scripts_postprocessing.ScriptPostprocessing):
@@ -9,22 +10,18 @@ class OldPhotoRestoration(scripts_postprocessing.ScriptPostprocessing):
order = 200409484
def ui(self):
with ui_components.InputAccordion(
False, label="Old Photo Restoration"
) as enable:
with InputAccordion(False, label="Old Photo Restoration") as enable:
proc_order = gr.Radio(
["Restoration First", "Upscale First"],
label="Processing Order",
choices=("Restoration First", "Upscale First"),
value="Restoration First",
label="Processing Order",
)
with gr.Row():
do_scratch = gr.Checkbox(label="Process Scratch")
do_face_res = gr.Checkbox(label="Face Restore")
do_scratch = gr.Checkbox(False, label="Process Scratch")
do_face_res = gr.Checkbox(False, label="Face Restore")
with gr.Row():
is_hr = gr.Checkbox(label="High Resolution")
is_hr = gr.Checkbox(False, label="High Resolution")
use_cpu = gr.Checkbox(True, label="Use CPU")
args = {


@@ -1,12 +1,12 @@
from Global.test import global_test
from Global.detection import global_detection
from lib_bopb2l.Global.test import global_test
from lib_bopb2l.Global.detection import global_detection
from Face_Detection.detect_all_dlib import detect
from Face_Detection.detect_all_dlib_HR import detect_hr
from Face_Detection.align_warp_back_multiple_dlib import align_warp
from Face_Detection.align_warp_back_multiple_dlib_HR import align_warp_hr
from lib_bopb2l.Face_Detection.detect_all_dlib import detect
from lib_bopb2l.Face_Detection.detect_all_dlib_HR import detect_hr
from lib_bopb2l.Face_Detection.align_warp_back_multiple_dlib import align_warp
from lib_bopb2l.Face_Detection.align_warp_back_multiple_dlib_HR import align_warp_hr
from Face_Enhancement.test_face import test_face
from lib_bopb2l.Face_Enhancement.test_face import test_face
from modules import scripts
from PIL import Image
@@ -15,11 +15,12 @@ import os
GLOBAL_CHECKPOINTS_FOLDER = os.path.join(
scripts.basedir(), "Global", "checkpoints", "restoration"
scripts.basedir(), "lib_bopb2l", "Global", "checkpoints", "restoration"
)
FACE_CHECKPOINTS_FOLDER = os.path.join(
scripts.basedir(), "Face_Enhancement", "checkpoints"
scripts.basedir(), "lib_bopb2l", "Face_Enhancement", "checkpoints"
)
FACE_ENHANCEMENT_CHECKPOINTS = ("Setting_9_epoch_100", "FaceSR_512")