mirror of
https://github.com/Haoming02/sd-webui-old-photo-restoration.git
synced 2026-01-26 19:29:52 +00:00
Commit: submodule
7 .gitignore vendored
@@ -1,8 +1 @@
# Junks
__pycache__
*.pyc
*~

# Models
*landmarks.dat
*/checkpoints/*
3 .gitmodules vendored Normal file
@@ -0,0 +1,3 @@
[submodule "lib_bopb2l"]
	path = lib_bopb2l
	url = https://github.com/Haoming02/BOP-B2L-Backend
@@ -1,430 +0,0 @@
# Copyright (c) Microsoft Corporation

from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from PIL import Image, ImageFilter
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
import numpy as np
import dlib
import cv2
import os


def calculate_cdf(histogram):
    """
    This method calculates the cumulative distribution function
    :param array histogram: The values of the histogram
    :return: normalized_cdf: The normalized cumulative distribution function
    :rtype: array
    """
    # Get the cumulative sum of the elements
    cdf = histogram.cumsum()

    # Normalize the cdf
    normalized_cdf = cdf / float(cdf.max())

    return normalized_cdf


def calculate_lookup(src_cdf, ref_cdf):
    """
    This method creates the lookup table
    :param array src_cdf: The cdf for the source image
    :param array ref_cdf: The cdf for the reference image
    :return: lookup_table: The lookup table
    :rtype: array
    """
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val in range(len(src_cdf)):
        for ref_pixel_val in range(len(ref_cdf)):
            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
                lookup_val = ref_pixel_val
                break
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table


def match_histograms(src_image, ref_image):
    """
    This method matches the source image histogram to the
    reference signal
    :param image src_image: The original source image
    :param image ref_image: The reference image
    :return: image_after_matching
    :rtype: image (array)
    """
    # Split the images into the different color channels
    # b means blue, g means green and r means red
    src_b, src_g, src_r = cv2.split(src_image)
    ref_b, ref_g, ref_r = cv2.split(ref_image)

    # Compute the b, g, and r histograms separately
    # The flatten() Numpy method returns a copy of the array
    # collapsed into one dimension.
    src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
    src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
    src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
    ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
    ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
    ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])

    # Compute the normalized cdf for the source and reference image
    src_cdf_blue = calculate_cdf(src_hist_blue)
    src_cdf_green = calculate_cdf(src_hist_green)
    src_cdf_red = calculate_cdf(src_hist_red)
    ref_cdf_blue = calculate_cdf(ref_hist_blue)
    ref_cdf_green = calculate_cdf(ref_hist_green)
    ref_cdf_red = calculate_cdf(ref_hist_red)

    # Make a separate lookup table for each color
    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)

    # Use the lookup function to transform the colors of the original
    # source image
    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
    green_after_transform = cv2.LUT(src_g, green_lookup_table)
    red_after_transform = cv2.LUT(src_r, red_lookup_table)

    # Put the image back together
    image_after_matching = cv2.merge(
        [blue_after_transform, green_after_transform, red_after_transform]
    )
    image_after_matching = cv2.convertScaleAbs(image_after_matching)

    return image_after_matching


def _standard_face_pts():
    pts = (
        np.array(
            [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
            np.float32,
        )
        / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def _origin_face_pts():
    pts = np.array(
        [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
        np.float32,
    )

    return np.reshape(pts, (5, 2))


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0

    h, w, c = img.shape
    if normalize:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)

    return affine


def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0

    h, w, c = img.shape
    if normalize:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(landmark, target_pts)

    return affine


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (
        2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
    ) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (
        2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
    ) / target_w - 1
    return theta


def blur_blending(im1, im2, mask):
    mask = mask * 255.0

    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)

    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))

    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    im = Image.composite(im1, im2, mask)
    im = Image.composite(im, im2, mask_blur)

    return np.array(im) / 255.0


def blur_blending_cv2(im1, im2, mask):
    mask = mask * 255.0

    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)

    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0

    im = im1 * mask_blur + (1 - mask_blur) * im2

    im /= 255.0
    im = np.clip(im, 0.0, 1.0)

    return im


def Poisson_blending(im1, im2, mask):
    mask = mask * 255.0
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask /= 255
    mask = 1 - mask
    mask *= 255

    mask = mask[:, :, 0]
    width, height, channels = im1.shape
    center = (int(height / 2), int(width / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"),
        im1.astype("uint8"),
        mask.astype("uint8"),
        center,
        cv2.MIXED_CLONE,
    )

    return result / 255.0


def Poisson_B(im1, im2, mask, center):
    mask = mask * 255.0

    result = cv2.seamlessClone(
        im2.astype("uint8"),
        im1.astype("uint8"),
        mask.astype("uint8"),
        center,
        cv2.NORMAL_CLONE,
    )

    return result / 255


def seamless_clone(old_face, new_face, raw_mask):
    height, width, _ = old_face.shape
    height = height // 2
    width = width // 2

    y_indices, x_indices, _ = np.nonzero(raw_mask)
    y_crop = slice(np.min(y_indices), np.max(y_indices))
    x_crop = slice(np.min(x_indices), np.max(x_indices))
    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))

    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask[insertion_mask != 0] = 255
    prior = np.rint(
        np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")
    ).astype("uint8")

    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
    if w < 4 or h < 4:
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member

    blended = blended[height:-height, width:-width]

    return blended.astype("float32") / 255.0


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):
    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


def align_warp(
    original_image: Image.Image,
    restored_faces: list,
) -> Image.Image:
    if len(restored_faces) == 0:
        return original_image

    face_detector = dlib.get_frontal_face_detector()
    landmark = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "shape_predictor_68_face_landmarks.dat",
    )
    landmark_locator = dlib.shape_predictor(landmark)

    origin_width, origin_height = original_image.size
    image = np.array(original_image)

    faces = face_detector(image)

    blended = image
    for face_id in range(len(faces)):
        current_face = faces[face_id]
        face_landmarks = landmark_locator(image, current_face)
        current_fl = search(face_landmarks)

        forward_mask = np.ones_like(image).astype("uint8")
        affine = compute_transformation_matrix(
            image, current_fl, False, target_face_scale=1.3
        )
        aligned_face = warp(
            image, affine, output_shape=(256, 256, 3), preserve_range=True
        )
        forward_mask = warp(
            forward_mask,
            affine,
            output_shape=(256, 256, 3),
            order=0,
            preserve_range=True,
        )

        affine_inverse = affine.inverse
        cur_face = np.array(restored_faces[face_id])

        # Histogram color matching
        A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
        B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
        B = match_histograms(B, A)
        cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

        warped_back = warp(
            cur_face,
            affine_inverse,
            output_shape=(origin_height, origin_width, 3),
            order=3,
            preserve_range=True,
        )

        backward_mask = warp(
            forward_mask,
            affine_inverse,
            output_shape=(origin_height, origin_width, 3),
            order=0,
            preserve_range=True,
        )  # nearest neighbour

        blended = blur_blending_cv2(warped_back, blended, backward_mask)
        # blur_blending_cv2 returns values in [0, 1]; scale back so the
        # next iteration blends in the same 0-255 range as the input image
        blended *= 255.0

    return Image.fromarray(img_as_ubyte(blended / 255.0))
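The file above, deleted in this commit apparently in favour of the lib_bopb2l submodule added earlier, wires dlib face detection, similarity-transform alignment, per-channel histogram matching, and blurred-mask blending into the single entry point align_warp. A minimal usage sketch; the module name and file paths are hypothetical, since the diff does not show where the file lived, and it assumes shape_predictor_68_face_landmarks.dat sits next to the module:

from PIL import Image

from align_warp_back import align_warp  # hypothetical module name

original = Image.open("old_photo.png").convert("RGB")
# Hypothetical inputs: one restored 256x256 crop per detected face,
# in the order the detector returns the faces.
restored_faces = [Image.open("face_0_restored.png").convert("RGB")]

result = align_warp(original, restored_faces)
result.save("old_photo_blended.png")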
@@ -1,430 +0,0 @@
# Copyright (c) Microsoft Corporation

from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from PIL import Image, ImageFilter
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
import numpy as np
import dlib
import cv2
import os


def calculate_cdf(histogram):
    """
    This method calculates the cumulative distribution function
    :param array histogram: The values of the histogram
    :return: normalized_cdf: The normalized cumulative distribution function
    :rtype: array
    """
    # Get the cumulative sum of the elements
    cdf = histogram.cumsum()

    # Normalize the cdf
    normalized_cdf = cdf / float(cdf.max())

    return normalized_cdf


def calculate_lookup(src_cdf, ref_cdf):
    """
    This method creates the lookup table
    :param array src_cdf: The cdf for the source image
    :param array ref_cdf: The cdf for the reference image
    :return: lookup_table: The lookup table
    :rtype: array
    """
    lookup_table = np.zeros(256)
    lookup_val = 0
    for src_pixel_val in range(len(src_cdf)):
        for ref_pixel_val in range(len(ref_cdf)):
            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
                lookup_val = ref_pixel_val
                break
        lookup_table[src_pixel_val] = lookup_val
    return lookup_table


def match_histograms(src_image, ref_image):
    """
    This method matches the source image histogram to the
    reference signal
    :param image src_image: The original source image
    :param image ref_image: The reference image
    :return: image_after_matching
    :rtype: image (array)
    """
    # Split the images into the different color channels
    # b means blue, g means green and r means red
    src_b, src_g, src_r = cv2.split(src_image)
    ref_b, ref_g, ref_r = cv2.split(ref_image)

    # Compute the b, g, and r histograms separately
    # The flatten() Numpy method returns a copy of the array
    # collapsed into one dimension.
    src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
    src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
    src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
    ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
    ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
    ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])

    # Compute the normalized cdf for the source and reference image
    src_cdf_blue = calculate_cdf(src_hist_blue)
    src_cdf_green = calculate_cdf(src_hist_green)
    src_cdf_red = calculate_cdf(src_hist_red)
    ref_cdf_blue = calculate_cdf(ref_hist_blue)
    ref_cdf_green = calculate_cdf(ref_hist_green)
    ref_cdf_red = calculate_cdf(ref_hist_red)

    # Make a separate lookup table for each color
    blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
    green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
    red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)

    # Use the lookup function to transform the colors of the original
    # source image
    blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
    green_after_transform = cv2.LUT(src_g, green_lookup_table)
    red_after_transform = cv2.LUT(src_r, red_lookup_table)

    # Put the image back together
    image_after_matching = cv2.merge(
        [blue_after_transform, green_after_transform, red_after_transform]
    )
    image_after_matching = cv2.convertScaleAbs(image_after_matching)

    return image_after_matching


def _standard_face_pts():
    pts = (
        np.array(
            [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
            np.float32,
        )
        / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def _origin_face_pts():
    pts = np.array(
        [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
        np.float32,
    )

    return np.reshape(pts, (5, 2))


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    h, w, c = img.shape
    if normalize:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)

    return affine


def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    h, w, c = img.shape
    if normalize:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(landmark, target_pts)

    return affine


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (
        2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
    ) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (
        2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
    ) / target_w - 1
    return theta


def blur_blending(im1, im2, mask):
    mask = mask * 255.0

    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)

    mask = Image.fromarray(mask.astype("uint8")).convert("L")
    im1 = Image.fromarray(im1.astype("uint8"))
    im2 = Image.fromarray(im2.astype("uint8"))

    mask_blur = mask.filter(ImageFilter.GaussianBlur(20))
    im = Image.composite(im1, im2, mask)
    im = Image.composite(im, im2, mask_blur)

    return np.array(im) / 255.0


def blur_blending_cv2(im1, im2, mask):
    mask = mask * 255.0

    kernel = np.ones((9, 9), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=3)

    mask_blur = cv2.GaussianBlur(mask, (25, 25), 0)
    mask_blur /= 255.0

    im = im1 * mask_blur + (1 - mask_blur) * im2

    im /= 255.0
    im = np.clip(im, 0.0, 1.0)

    return im


def Poisson_blending(im1, im2, mask):
    mask = mask * 255.0
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=1)
    mask /= 255
    mask = 1 - mask
    mask *= 255

    mask = mask[:, :, 0]
    width, height, channels = im1.shape
    center = (int(height / 2), int(width / 2))
    result = cv2.seamlessClone(
        im2.astype("uint8"),
        im1.astype("uint8"),
        mask.astype("uint8"),
        center,
        cv2.MIXED_CLONE,
    )

    return result / 255.0


def Poisson_B(im1, im2, mask, center):
    mask = mask * 255.0

    result = cv2.seamlessClone(
        im2.astype("uint8"),
        im1.astype("uint8"),
        mask.astype("uint8"),
        center,
        cv2.NORMAL_CLONE,
    )

    return result / 255


def seamless_clone(old_face, new_face, raw_mask):
    height, width, _ = old_face.shape
    height = height // 2
    width = width // 2

    y_indices, x_indices, _ = np.nonzero(raw_mask)
    y_crop = slice(np.min(y_indices), np.max(y_indices))
    x_crop = slice(np.min(x_indices), np.max(x_indices))
    y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height))
    x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width))

    insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8")
    insertion_mask[insertion_mask != 0] = 255
    prior = np.rint(
        np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")
    ).astype("uint8")

    n_mask = insertion_mask[1:-1, 1:-1, :]
    n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0)
    x, y, w, h = cv2.boundingRect(n_mask[:, :, 0])
    if w < 4 or h < 4:
        blended = prior
    else:
        blended = cv2.seamlessClone(
            insertion,  # pylint: disable=no-member
            prior,
            insertion_mask,
            (x_center, y_center),
            cv2.NORMAL_CLONE,
        )  # pylint: disable=no-member

    blended = blended[height:-height, width:-width]

    return blended.astype("float32") / 255.0


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):
    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


def align_warp_hr(
    original_image: Image.Image,
    restored_faces: list,
) -> Image.Image:
    if len(restored_faces) == 0:
        return original_image

    face_detector = dlib.get_frontal_face_detector()
    landmark = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "shape_predictor_68_face_landmarks.dat",
    )
    landmark_locator = dlib.shape_predictor(landmark)

    origin_width, origin_height = original_image.size
    image = np.array(original_image)

    faces = face_detector(image)

    blended = image
    for face_id in range(len(faces)):
        current_face = faces[face_id]
        face_landmarks = landmark_locator(image, current_face)
        current_fl = search(face_landmarks)

        forward_mask = np.ones_like(image).astype("uint8")
        affine = compute_transformation_matrix(
            image, current_fl, False, target_face_scale=1.3
        )
        aligned_face = warp(
            image, affine, output_shape=(512, 512, 3), preserve_range=True
        )
        forward_mask = warp(
            forward_mask,
            affine,
            output_shape=(512, 512, 3),
            order=0,
            preserve_range=True,
        )

        affine_inverse = affine.inverse
        cur_face = np.array(restored_faces[face_id])

        # Histogram color matching
        A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR)
        B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR)
        B = match_histograms(B, A)
        cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB)

        warped_back = warp(
            cur_face,
            affine_inverse,
            output_shape=(origin_height, origin_width, 3),
            order=3,
            preserve_range=True,
        )

        backward_mask = warp(
            forward_mask,
            affine_inverse,
            output_shape=(origin_height, origin_width, 3),
            order=0,
            preserve_range=True,
        )  # nearest neighbour

        blended = blur_blending_cv2(warped_back, blended, backward_mask)
        # blur_blending_cv2 returns values in [0, 1]; scale back so the
        # next iteration blends in the same 0-255 range as the input image
        blended *= 255.0

    return Image.fromarray(img_as_ubyte(blended / 255.0))
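This second deleted file is the high-resolution twin of the previous one: apart from the align_warp_hr name, the only substantive changes are the 512-pixel target points and the (512, 512, 3) warp shapes. The shared match_histograms helper can also be exercised on its own; a small sketch with hypothetical file names:

import cv2

src = cv2.imread("restored_face.png")  # BGR, uint8
ref = cv2.imread("aligned_face.png")   # BGR, uint8
# Recolor src so each of its B/G/R channel histograms matches ref's
matched = match_histograms(src, ref)
cv2.imwrite("restored_face_matched.png", matched)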
@@ -1,153 +0,0 @@
# Copyright (c) Microsoft Corporation

from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import dlib
import os


def _standard_face_pts():
    pts = (
        np.array(
            [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
            np.float32,
        )
        / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):
    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0

    h, w, c = img.shape
    if normalize:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)

    return affine.params


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (
        2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
    ) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (
        2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
    ) / target_w - 1
    return theta


def detect(input_image: Image.Image) -> list:
    face_detector = dlib.get_frontal_face_detector()
    detected_faces = []

    landmark = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "shape_predictor_68_face_landmarks.dat",
    )
    landmark_locator = dlib.shape_predictor(landmark)

    image = np.array(input_image)
    faces = face_detector(image)

    if len(faces) == 0:
        return detected_faces

    for face_id in range(len(faces)):
        current_face = faces[face_id]
        face_landmarks = landmark_locator(image, current_face)
        current_fl = search(face_landmarks)

        affine = compute_transformation_matrix(
            image, current_fl, False, target_face_scale=1.3
        )

        aligned_face = warp(image, affine, output_shape=(256, 256, 3))
        detected_faces.append(Image.fromarray(img_as_ubyte(aligned_face)))

    return detected_faces
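detect is the front half of the face pipeline: it crops and aligns each detected face to a 256x256 canvas, ready for restoration, after which align_warp pastes the results back. A sketch of the round trip; the module names are hypothetical, since the diff omits the file paths:

from PIL import Image

from detect_faces import detect          # hypothetical module name
from align_warp_back import align_warp   # hypothetical module name

photo = Image.open("old_photo.png").convert("RGB")
aligned = detect(photo)  # list of 256x256 PIL crops, one per face
# Placeholder: run the face-restoration model on each aligned crop here
restored = aligned
result = align_warp(photo, restored)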
@@ -1,153 +0,0 @@
# Copyright (c) Microsoft Corporation

from skimage.transform import SimilarityTransform
from matplotlib.patches import Rectangle
from skimage.transform import warp
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import dlib
import os


def _standard_face_pts():
    pts = (
        np.array(
            [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4],
            np.float32,
        )
        / 256.0
        - 1.0
    )

    return np.reshape(pts, (5, 2))


def get_landmark(face_landmarks, id):
    part = face_landmarks.part(id)
    x = part.x
    y = part.y

    return (x, y)


def search(face_landmarks):
    x1, y1 = get_landmark(face_landmarks, 36)
    x2, y2 = get_landmark(face_landmarks, 39)
    x3, y3 = get_landmark(face_landmarks, 42)
    x4, y4 = get_landmark(face_landmarks, 45)

    x_nose, y_nose = get_landmark(face_landmarks, 30)

    x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48)
    x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54)

    x_left_eye = int((x1 + x2) / 2)
    y_left_eye = int((y1 + y2) / 2)
    x_right_eye = int((x3 + x4) / 2)
    y_right_eye = int((y3 + y4) / 2)

    results = np.array(
        [
            [x_left_eye, y_left_eye],
            [x_right_eye, y_right_eye],
            [x_nose, y_nose],
            [x_left_mouth, y_left_mouth],
            [x_right_mouth, y_right_mouth],
        ]
    )

    return results


def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0):
    std_pts = _standard_face_pts()  # [-1,1]
    target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0

    h, w, c = img.shape
    if normalize:
        landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0
        landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0

    affine = SimilarityTransform()
    affine.estimate(target_pts, landmark)

    return affine.params


def show_detection(image, box, landmark):
    plt.imshow(image)
    print(box[2] - box[0])
    plt.gca().add_patch(
        Rectangle(
            (box[1], box[0]),
            box[2] - box[0],
            box[3] - box[1],
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
    )
    plt.scatter(landmark[0][0], landmark[0][1])
    plt.scatter(landmark[1][0], landmark[1][1])
    plt.scatter(landmark[2][0], landmark[2][1])
    plt.scatter(landmark[3][0], landmark[3][1])
    plt.scatter(landmark[4][0], landmark[4][1])
    plt.show()


def affine2theta(affine, input_w, input_h, target_w, target_h):
    # param = np.linalg.inv(affine)
    param = affine
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0] * input_h / target_h
    theta[0, 1] = param[0, 1] * input_w / target_h
    theta[0, 2] = (
        2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w
    ) / target_h - 1
    theta[1, 0] = param[1, 0] * input_h / target_w
    theta[1, 1] = param[1, 1] * input_w / target_w
    theta[1, 2] = (
        2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w
    ) / target_w - 1
    return theta


def detect_hr(input_image: Image.Image) -> list:
    face_detector = dlib.get_frontal_face_detector()
    detected_faces = []

    landmark = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "shape_predictor_68_face_landmarks.dat",
    )
    landmark_locator = dlib.shape_predictor(landmark)

    image = np.array(input_image)
    faces = face_detector(image)

    if len(faces) == 0:
        return detected_faces

    for face_id in range(len(faces)):
        current_face = faces[face_id]
        face_landmarks = landmark_locator(image, current_face)
        current_fl = search(face_landmarks)

        affine = compute_transformation_matrix(
            image, current_fl, False, target_face_scale=1.3
        )

        aligned_face = warp(image, affine, output_shape=(512, 512, 3))
        detected_faces.append(Image.fromarray(img_as_ubyte(aligned_face)))

    return detected_faces
@@ -1,24 +0,0 @@
# Copyright (c) Microsoft Corporation

import torch.utils.data
from .face_dataset import FaceTestDataset


def create_dataloader(opt, faces: list):
    instance = FaceTestDataset()
    instance.initialize(opt, faces)

    print(
        f"Dataset [{type(instance).__name__}] with {len(instance)} faces was created..."
    )

    dataloader = torch.utils.data.DataLoader(
        instance,
        batch_size=opt.batchSize,
        shuffle=not opt.serial_batches,
        num_workers=int(opt.nThreads),
        drop_last=opt.isTrain,
    )

    return dataloader
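A minimal sketch of driving create_dataloader; the option object is a hypothetical stand-in for the project's parsed command-line options, with only the attributes this file and FaceTestDataset actually read:

from types import SimpleNamespace

opt = SimpleNamespace(
    batchSize=1,
    serial_batches=True,
    nThreads=0,
    isTrain=False,
    # read by FaceTestDataset via get_params/get_transform:
    preprocess_mode="resize_and_crop",
    load_size=256,
    crop_size=256,
    no_flip=True,
)
# faces: a list of PIL images, e.g. the output of detect()
loader = create_dataloader(opt, faces)
for batch in loader:
    print(batch["image"].shape)  # torch.Size([1, 3, 256, 256])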
@@ -1,125 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess_mode == "resize_and_crop":
        new_h = new_w = opt.load_size
    elif opt.preprocess_mode == "scale_width_and_crop":
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    elif opt.preprocess_mode == "scale_shortside_and_crop":
        ss, ls = min(w, h), max(w, h)  # shortside and longside
        width_is_shorter = w == ss
        ls = int(opt.load_size * ls / ss)
        new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5
    return {"crop_pos": (x, y), "flip": flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
    transform_list = []
    if "resize" in opt.preprocess_mode:
        osize = [opt.load_size, opt.load_size]
        transform_list.append(transforms.Resize(osize, interpolation=method))
    elif "scale_width" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
    elif "scale_shortside" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))

    if "crop" in opt.preprocess_mode:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size)))

    if opt.preprocess_mode == "none":
        base = 32
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.preprocess_mode == "fixed":
        w = opt.crop_size
        h = round(opt.crop_size / opt.aspect_ratio)
        transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"])))

    if toTensor:
        transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __resize(img, w, h, method=Image.BICUBIC):
    return img.resize((w, h), method)


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __scale_shortside(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    ss, ls = min(ow, oh), max(ow, oh)  # shortside and longside
    width_is_shorter = ow == ss
    if ss == target_width:
        return img
    ls = int(target_width * ls / ss)
    nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
    return img.resize((nw, nh), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    return img.crop((x1, y1, x1 + tw, y1 + th))


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
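The helpers above compose into a standard torchvision pipeline. A short sketch with a hypothetical option object, assuming get_params and get_transform are imported from this module:

from types import SimpleNamespace
from PIL import Image

opt = SimpleNamespace(
    preprocess_mode="resize_and_crop",
    load_size=286,
    crop_size=256,
    isTrain=True,
    no_flip=False,
)
img = Image.new("RGB", (640, 480))  # stand-in input image
params = get_params(opt, img.size)  # random crop position and flip flag
transform = get_transform(opt, params)
tensor = transform(img)             # 3x256x256 tensor, normalized to [-1, 1]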
@@ -1,56 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .pix2pix_dataset import Pix2pixDataset
from .image_folder import make_dataset


class CustomDataset(Pix2pixDataset):
    """Dataset that loads images from directories.
    Use the options --label_dir, --image_dir, --instance_dir to specify the directories.
    The images in the directories are sorted in alphabetical order and paired in order.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        parser.set_defaults(preprocess_mode="resize_and_crop")
        load_size = 286 if is_train else 256
        parser.set_defaults(load_size=load_size)
        parser.set_defaults(crop_size=256)
        parser.set_defaults(display_winsize=256)
        parser.set_defaults(label_nc=13)
        parser.set_defaults(contain_dontcare_label=False)

        parser.add_argument(
            "--label_dir", type=str, required=True, help="path to the directory that contains label images"
        )
        parser.add_argument(
            "--image_dir", type=str, required=True, help="path to the directory that contains photo images"
        )
        parser.add_argument(
            "--instance_dir",
            type=str,
            default="",
            help="path to the directory that contains instance maps. Leave blank if it does not exist",
        )
        return parser

    def get_paths(self, opt):
        label_dir = opt.label_dir
        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)

        image_dir = opt.image_dir
        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)

        if len(opt.instance_dir) > 0:
            instance_dir = opt.instance_dir
            instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
        else:
            instance_paths = []

        assert len(label_paths) == len(
            image_paths
        ), "The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir)

        return label_paths, image_paths, instance_paths
@@ -1,72 +0,0 @@
# Copyright (c) Microsoft Corporation

from .base_dataset import BaseDataset, get_params, get_transform
import torch


class FaceTestDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        # parser.set_defaults(contain_dontcare_label=False)
        # parser.set_defaults(no_instance=True)
        return parser

    def initialize(self, opt, faces: list):
        self.opt = opt
        self.images = faces  # All the images

        self.parts = [
            "skin",
            "hair",
            "l_brow",
            "r_brow",
            "l_eye",
            "r_eye",
            "eye_g",
            "l_ear",
            "r_ear",
            "ear_r",
            "nose",
            "mouth",
            "u_lip",
            "l_lip",
            "neck",
            "neck_l",
            "cloth",
            "hat",
        ]

        size = len(self.images)
        self.dataset_size = size

    def __getitem__(self, index):
        params = get_params(self.opt, (-1, -1))

        image = self.images[index].convert("RGB")

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        full_label = []

        for _ in self.parts:
            current_part = torch.zeros((self.opt.load_size, self.opt.load_size))
            full_label.append(current_part)

        full_label_tensor = torch.stack(full_label, 0)

        input_dict = {
            "label": full_label_tensor,
            "image": image_tensor,
            "path": "",
        }

        return input_dict

    def __len__(self):
        return self.dataset_size
@@ -1,101 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tiff",
    ".webp",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset_rec(dir, images):
    assert os.path.isdir(dir), "%s is not a valid directory" % dir

    for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)


def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    images = []

    if read_cache:
        possible_filelist = os.path.join(dir, "files.list")
        if os.path.isfile(possible_filelist):
            with open(possible_filelist, "r") as f:
                images = f.read().splitlines()
                return images

    if recursive:
        make_dataset_rec(dir, images)
    else:
        assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir

        for root, dnames, fnames in sorted(os.walk(dir)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append(path)

    if write_cache:
        filelist_cache = os.path.join(dir, "files.list")
        with open(filelist_cache, "w") as f:
            for path in images:
                f.write("%s\n" % path)
        print("wrote filelist cache at %s" % filelist_cache)

    return images


def default_loader(path):
    return Image.open(path).convert("RGB")


class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise RuntimeError(
                "Found 0 images in: " + root + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
            )

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
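Usage of the folder scanner and dataset wrapper is straightforward; a short sketch (the directory name is hypothetical):

paths = make_dataset("datasets/faces", recursive=True)
print(len(paths), "images found")

dataset = ImageFolder("datasets/faces", return_paths=True)
img, path = dataset[0]  # PIL image plus its file path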
@@ -1,108 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
from ..util import util
import os


class Pix2pixDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        return parser

    def initialize(self, opt):
        self.opt = opt

        label_paths, image_paths, instance_paths = self.get_paths(opt)

        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)

        label_paths = label_paths[: opt.max_dataset_size]
        image_paths = image_paths[: opt.max_dataset_size]
        instance_paths = instance_paths[: opt.max_dataset_size]

        if not opt.no_pairing_check:
            for path1, path2 in zip(label_paths, image_paths):
                assert self.paths_match(path1, path2), (
                    "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
                    % (path1, path2)
                )

        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths

        size = len(self.label_paths)
        self.dataset_size = size

    def get_paths(self, opt):
        label_paths = []
        image_paths = []
        instance_paths = []
        assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
        return label_paths, image_paths, instance_paths

    def paths_match(self, path1, path2):
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        # Label Image
        label_path = self.label_paths[index]
        label = Image.open(label_path)
        params = get_params(self.opt, label.size)
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc

        # input image (real images)
        image_path = self.image_paths[index]
        assert self.paths_match(
            label_path, image_path
        ), "The label_path %s and image_path %s don't match." % (label_path, image_path)
        image = Image.open(image_path)
        image = image.convert("RGB")

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # if using instance maps
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            instance = Image.open(instance_path)
            if instance.mode == "L":
                instance_tensor = transform_label(instance) * 255
                instance_tensor = instance_tensor.long()
            else:
                instance_tensor = transform_label(instance)

        input_dict = {
            "label": label_tensor,
            "instance": instance_tensor,
            "image": image_tensor,
            "path": image_path,
        }

        # Give subclasses a chance to modify the final output
        self.postprocess(input_dict)

        return input_dict

    def postprocess(self, input_dict):
        return input_dict

    def __len__(self):
        return self.dataset_size
@@ -1,45 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch


def find_model_using_name(model_name):
    # Given the option --model [modelname],
    # the file "models/modelname_model.py" will be imported.
    # Only "pix2pix" is supported here, hence the assert and the
    # direct import below instead of a dynamic one.
    assert model_name == 'pix2pix'
    model_filename = f"{model_name}_model"

    from . import pix2pix_model as modellib

    # In the file, the class called ModelNameModel() will
    # be instantiated. It has to be a subclass of torch.nn.Module,
    # and it is case-insensitive.
    model = None
    target_model_name = model_name.replace("_", "") + "model"
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module):
            model = cls

    if model is None:
        raise SystemError(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
            % (model_filename, target_model_name)
        )

    return model


def get_option_setter(model_name):
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % (type(instance).__name__))

    return instance
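Because find_model_using_name asserts the name is 'pix2pix', create_model always resolves to the Pix2Pix model class; a sketch, with the option object shown as a hypothetical namespace (the real model constructor reads many more options than this):

from types import SimpleNamespace

opt = SimpleNamespace(model="pix2pix")  # plus everything the model itself requires
model = create_model(opt)  # prints: model [Pix2PixModel] was created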
@@ -1,57 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from .base_network import BaseNetwork
from .generator import *
from .encoder import *
from ...util import util


def find_network_using_name(target_network_name, filename):
    target_class_name = target_network_name + filename
    module_name = "models.networks." + filename
    network = util.find_class_in_module(target_class_name, module_name)

    assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network

    return network


def modify_commandline_options(parser, is_train):
    opt, _ = parser.parse_known_args()

    netG_cls = find_network_using_name(opt.netG, "generator")
    parser = netG_cls.modify_commandline_options(parser, is_train)
    if is_train:
        netD_cls = find_network_using_name(opt.netD, "discriminator")
        parser = netD_cls.modify_commandline_options(parser, is_train)
    netE_cls = find_network_using_name("conv", "encoder")
    parser = netE_cls.modify_commandline_options(parser, is_train)

    return parser


def create_network(cls, opt):
    net = cls(opt)
    net.print_network()
    if torch.cuda.is_available() and len(opt.gpu_ids) > 0:
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net


def define_G(opt):
    netG_cls = find_network_using_name(opt.netG, "generator")
    return create_network(netG_cls, opt)


def define_D(opt):
    netD_cls = find_network_using_name(opt.netD, "discriminator")
    return create_network(netD_cls, opt)


def define_E(opt):
    # there exists only one encoder type
    netE_cls = find_network_using_name("conv", "encoder")
    return create_network(netE_cls, opt)
@@ -1,173 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from .normalization import SPADE


# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This seemed like a standard architecture for unconditional or
# class-conditional GANs built from residual blocks.
# The code was inspired by https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)

        self.opt = opt
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if "spectral" in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map, as input
    def forward(self, x, seg, degraded_image):
        x_s = self.shortcut(x, seg, degraded_image)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image)))

        out = x_s + dx

        return out

    def shortcut(self, x, seg, degraded_image):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg, degraded_image))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
    def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
        super().__init__()

        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        y = self.conv_block(x)
        out = x + y
        return out


# VGG architecture, used for the perceptual loss with a pretrained VGG network
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class SPADEResnetBlock_non_spade(nn.Module):
    def __init__(self, fin, fout, opt):
        super().__init__()
        # Attributes
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)

        self.opt = opt
        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if "spectral" in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt)

    # note this variant still accepts |seg| and |degraded_image| for interface
    # compatibility, but skips the SPADE layers in the forward pass
    def forward(self, x, seg, degraded_image):
        x_s = self.shortcut(x, seg, degraded_image)

        dx = self.conv_0(self.actvn(x))
        dx = self.conv_1(self.actvn(dx))

        out = x_s + dx

        return out

    def shortcut(self, x, seg, degraded_image):
        if self.learned_shortcut:
            x_s = self.conv_s(x)
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)
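Usage sketch (not part of this file): the five feature levels returned by VGG19.forward above are typically combined into a perceptual loss; the 1/32 ... 1 layer weights below follow the common pix2pixHD scheme and are an assumption here.

import torch.nn as nn

def vgg_perceptual_loss(vgg, fake, real):
    # vgg: an instance of the VGG19 module above (frozen weights)
    weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
    criterion = nn.L1Loss()
    feats_fake, feats_real = vgg(fake), vgg(real)
    loss = 0.0
    for w, ff, fr in zip(weights, feats_fake, feats_real):
        loss = loss + w * criterion(ff, fr.detach())  # don't backprop into the target
    return loss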
@@ -1,58 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.nn as nn
from torch.nn import init


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            "Network [%s] was created. Total number of parameters: %.1f million."
            % (type(self).__name__, num_params / 1000000)
        )

    def init_weights(self, init_type="normal", gain=0.02):
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find("BatchNorm2d") != -1:
                if hasattr(m, "weight") and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, "bias") and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
                if init_type == "normal":
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == "xavier":
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == "xavier_uniform":
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == "kaiming":
                    init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
                elif init_type == "orthogonal":
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == "none":  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
                if hasattr(m, "bias") and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, "init_weights"):
                m.init_weights(init_type, gain)
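A minimal usage sketch for the class above (layer sizes arbitrary; assumes BaseNetwork is importable from this module):

import torch.nn as nn

class TinyNet(BaseNetwork):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))

net = TinyNet()
net.print_network()               # prints the parameter count in millions
net.init_weights("xavier", 0.02)  # Conv weights: Xavier; BatchNorm weights: N(1.0, 0.02)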
@@ -1,53 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from .base_network import BaseNetwork
from .normalization import get_nonspade_norm_layer


class ConvEncoder(BaseNetwork):
    """ Same architecture as the image discriminator """

    def __init__(self, opt):
        super().__init__()

        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

        self.so = s0 = 4
        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        if x.size(2) != 256 or x.size(3) != 256:
            x = F.interpolate(x, size=(256, 256), mode="bilinear")

        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        x = self.layer3(self.actvn(x))
        x = self.layer4(self.actvn(x))
        x = self.layer5(self.actvn(x))
        if self.opt.crop_size >= 256:
            x = self.layer6(self.actvn(x))
        x = self.actvn(x)

        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)

        return mu, logvar
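Sketch of how the (mu, logvar) pair above is usually consumed; Pix2PixModel.encode_z later in this diff calls a reparameterize helper along these lines (this version is an assumption, not the repo's own):

import torch

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)   # logvar = log(sigma^2)
    eps = torch.randn_like(std)
    return mu + eps * std           # differentiable sample from N(mu, sigma^2)

mu, logvar = torch.zeros(4, 256), torch.zeros(4, 256)
z = reparameterize(mu, logvar)      # (4, 256); standard normal when mu=0, logvar=0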
@@ -1,233 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_network import BaseNetwork
from .normalization import get_nonspade_norm_layer
from .architecture import ResnetBlock as ResnetBlock
from .architecture import SPADEResnetBlock as SPADEResnetBlock
from .architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade


class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
        parser.add_argument(
            "--num_upsampling_layers",
            choices=("normal", "more", "most"),
            default="normal",
            help="If 'more', adds an upsampling layer between the two middle resnet blocks. If 'most', also adds one more upsampling + resnet layer at the end of the generator",
        )

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        print("The size of the latent vector is [%d,%d]" % (self.sw, self.sh))

        if opt.use_vae:
            # In the VAE case, we will sample from a random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # a downsampled segmentation map instead of a random z
            if self.opt.no_parsing_map:
                self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
            else:
                self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
            self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        else:
            self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
            self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        else:
            self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
            self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        else:
            self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
            self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        else:
            self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
            self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        else:
            self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
            self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
        else:
            self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == "most":
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == "normal":
            num_up_layers = 5
        elif opt.num_upsampling_layers == "more":
            num_up_layers = 6
        elif opt.num_upsampling_layers == "most":
            num_up_layers = 7
        else:
            raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)

        sw = opt.load_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

    def forward(self, input, degraded_image, z=None):
        seg = input

        if self.opt.use_vae:
            # we sample z from a unit normal and reshape the tensor
            if z is None:
                z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())
            x = self.fc(z)
            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            # we downsample the segmap and run a convolution
            if self.opt.no_parsing_map:
                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
            else:
                x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
            x = self.fc(x)

        x = self.head_0(x, seg, degraded_image)

        x = self.up(x)
        x = self.G_middle_0(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
            x = self.up(x)

        x = self.G_middle_1(x, seg, degraded_image)

        x = self.up(x)
        x = self.up_0(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_1(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_2(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_3(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "most":
            x = self.up(x)
            x = self.up_4(x, seg, degraded_image)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)  # F.tanh is deprecated; torch.tanh is equivalent

        return x


class Pix2PixHDGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
        )
        parser.add_argument(
            "--resnet_n_blocks",
            type=int,
            default=9,
            help="number of residual blocks in the global generator network",
        )
        parser.add_argument(
            "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
        )
        parser.add_argument(
            "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
        )
        # parser.set_defaults(norm_G='instance')
        return parser

    def __init__(self, opt):
        super().__init__()
        input_nc = 3

        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
        activation = nn.ReLU(False)

        model = []

        # initial conv
        model += [
            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
            norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
            activation,
        ]

        # downsample
        mult = 1
        for i in range(opt.resnet_n_downsample):
            model += [
                norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
                activation,
            ]
            mult *= 2

        # resnet blocks
        for i in range(opt.resnet_n_blocks):
            model += [
                ResnetBlock(
                    opt.ngf * mult,
                    norm_layer=norm_layer,
                    activation=activation,
                    kernel_size=opt.resnet_kernel_size,
                )
            ]

        # upsample
        for i in range(opt.resnet_n_downsample):
            nc_in = int(opt.ngf * mult)
            nc_out = int((opt.ngf * mult) / 2)
            model += [
                norm_layer(
                    nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
                ),
                activation,
            ]
            mult = mult // 2

        # final output conv
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input, degraded_image, z=None):
        return self.model(degraded_image)
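Worked example for compute_latent_vector_size above (option values assumed):

# num_upsampling_layers == "normal" gives 5 upsampling stages, so a 256-pixel
# load_size with aspect_ratio 1.0 starts the generator at an 8x8 grid:
load_size, aspect_ratio, num_up_layers = 256, 1.0, 5
sw = load_size // (2 ** num_up_layers)
sh = round(sw / aspect_ratio)
assert (sw, sh) == (8, 8)  # five x2 upsamples take 8 back to 256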
@@ -1,100 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from .sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm


def get_nonspade_norm_layer(opt, norm_type="instance"):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, "out_channels"):
            return getattr(layer, "out_channels")
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral") :]
        else:
            # without this branch, subnorm_type would be unbound for
            # plain (non-spectral) norm types
            subnorm_type = norm_type

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer


class SPADE(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc, opt):
        super().__init__()

        assert config_text.startswith("spade")
        parsed = re.search(r"spade(\D+)(\d)x\d", config_text)  # raw string avoids invalid escape sequences
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        self.opt = opt
        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "syncbatch":
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2

        if self.opt.no_parsing_map:
            self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        else:
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
            )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, degraded_image):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on the semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")

        if self.opt.no_parsing_map:
            actv = self.mlp_shared(degraded_face)
        else:
            actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
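Numeric sketch of the denormalization applied in SPADE.forward above: the parameter-free norm output is rescaled per-pixel by (1 + gamma) and shifted by beta, both predicted from the conditioning input (tensor values here are arbitrary):

import torch

normalized = torch.randn(1, 4, 8, 8)   # stand-in for param_free_norm(x)
gamma = torch.full((1, 4, 8, 8), 0.5)  # stand-in for mlp_gamma(actv)
beta = torch.full((1, 4, 8, 8), -0.1)  # stand-in for mlp_beta(actv)
out = normalized * (1 + gamma) + beta  # same formula as in forward()
assert out.shape == normalized.shape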
@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import set_sbn_eps_mode
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
@@ -1,412 +0,0 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'set_sbn_eps_mode',
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


SBN_EPS_MODE = 'clamp'


def set_sbn_eps_mode(mode):
    global SBN_EPS_MODE
    assert mode in ('clamp', 'plus')
    SBN_EPS_MODE = mode


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
                                                     track_running_stats=track_running_stats)

        if not self.track_running_stats:
            import warnings
            warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        if SBN_EPS_MODE == 'clamp':
            return mean, bias_var.clamp(self.eps) ** -0.5
        elif SBN_EPS_MODE == 'plus':
            return mean, (bias_var + self.eps) ** -0.5
        else:
            raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))


@contextlib.contextmanager
def patch_sync_batchnorm():
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup


def convert_model(module):
    """Traverse the input module and its children recursively
    and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
    with SynchronizedBatchNorm*N*d

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after conversion, m uses SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
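A single-device sanity sketch of the statistics math in _compute_mean_std above (no CUDA needed; shapes arbitrary):

import torch

x = torch.randn(6, 3, 5, 5)
xf = x.transpose(0, 1).reshape(3, -1)    # per-channel view; size = 6 * 5 * 5
sum_, ssum, size = xf.sum(1), (xf ** 2).sum(1), xf.size(1)
mean = sum_ / size
bias_var = (ssum - sum_ * mean) / size   # biased variance, as in the code
inv_std = (bias_var + 1e-5) ** -0.5      # the 'plus' eps mode
assert torch.allclose(mean, xf.mean(1))
assert torch.allclose(bias_var, xf.var(1, unbiased=False))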
@@ -1,74 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batchnorm_reimpl.py
# Author : acgtyrant
# Date : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
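The class above exists precisely for numerical cross-checking; a minimal sketch comparing it against the built-in layer in training mode (the affine parameters are copied first because reset_parameters draws the weight randomly):

import torch
import torch.nn as nn

ref = nn.BatchNorm2d(8)                    # training mode by default
reimpl = BatchNorm2dReimpl(8)
reimpl.weight.data.copy_(ref.weight.data)
reimpl.bias.data.copy_(ref.bias.data)

x = torch.randn(4, 8, 16, 16)
assert torch.allclose(reimpl(x), ref(x), atol=1e-5)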
@@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, "Previous result hasn't been fetched."
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as the data parallel will trigger a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
      to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
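A small threaded sketch of the register_slave / run_slave / run_master protocol above: two slave threads send integers, and the master callback sums every message and echoes the total back to all participants.

import threading

def callback(intermediates):
    # intermediates: [(0, master_msg), (id, slave_msg), ...]; the returned
    # list must keep the master's entry (identifier 0) first.
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(callback)
pipes = [master.register_slave(i) for i in (1, 2)]

results = []
threads = [
    threading.Thread(target=lambda p=p, v=v: results.append(p.run_slave(v)))
    for p, v in zip(pipes, (10, 20))
]
for t in threads:
    t.start()
master_result = master.run_master(5)  # the master contributes 5
for t in threads:
    t.join()
assert master_result == 35 and results == [35, 35]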
@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
@@ -1,246 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
from . import networks
|
||||
from ..util import util
|
||||
|
||||
|
||||
class Pix2PixModel(torch.nn.Module):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
networks.modify_commandline_options(parser, is_train)
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor
|
||||
self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor
|
||||
|
||||
self.netG, self.netD, self.netE = self.initialize_networks(opt)
|
||||
|
||||
# set loss functions
|
||||
if opt.isTrain:
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
|
||||
self.criterionFeat = torch.nn.L1Loss()
|
||||
if not opt.no_vgg_loss:
|
||||
self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
|
||||
if opt.use_vae:
|
||||
self.KLDLoss = networks.KLDLoss()
|
||||
|
||||
# Entry point for all calls involving forward pass
|
||||
# of deep networks. We used this approach since DataParallel module
|
||||
# can't parallelize custom functions, we branch to different
|
||||
# routines based on |mode|.
|
||||
def forward(self, data, mode):
|
||||
input_semantics, real_image, degraded_image = self.preprocess_input(data)
|
||||
|
||||
if mode == "generator":
|
||||
g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image)
|
||||
return g_loss, generated
|
||||
elif mode == "discriminator":
|
||||
d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image)
|
||||
return d_loss
|
||||
elif mode == "encode_only":
|
||||
z, mu, logvar = self.encode_z(real_image)
|
||||
return mu, logvar
|
||||
elif mode == "inference":
|
||||
with torch.no_grad():
|
||||
fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
|
||||
return fake_image
|
||||
else:
|
||||
raise ValueError("|mode| is invalid")
|
||||
|
||||
def create_optimizers(self, opt):
|
||||
G_params = list(self.netG.parameters())
|
||||
if opt.use_vae:
|
||||
G_params += list(self.netE.parameters())
|
||||
if opt.isTrain:
|
||||
D_params = list(self.netD.parameters())
|
||||
|
||||
beta1, beta2 = opt.beta1, opt.beta2
|
||||
if opt.no_TTUR:
|
||||
G_lr, D_lr = opt.lr, opt.lr
|
||||
else:
|
||||
G_lr, D_lr = opt.lr / 2, opt.lr * 2
|
||||
|
||||
optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
|
||||
optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
|
||||
|
||||
return optimizer_G, optimizer_D
|
||||
|
||||
def save(self, epoch):
|
||||
util.save_network(self.netG, "G", epoch, self.opt)
|
||||
util.save_network(self.netD, "D", epoch, self.opt)
|
||||
if self.opt.use_vae:
|
||||
util.save_network(self.netE, "E", epoch, self.opt)
|
||||
|
||||
############################################################################
|
||||
# Private helper methods
|
||||
############################################################################
|
||||
|
||||
def initialize_networks(self, opt):
|
||||
netG = networks.define_G(opt)
|
||||
netD = networks.define_D(opt) if opt.isTrain else None
|
||||
netE = networks.define_E(opt) if opt.use_vae else None
|
||||
|
||||
if not opt.isTrain or opt.continue_train:
|
||||
netG = util.load_network(netG, "G", opt.which_epoch, opt)
|
||||
if opt.isTrain:
|
||||
netD = util.load_network(netD, "D", opt.which_epoch, opt)
|
||||
if opt.use_vae:
|
||||
netE = util.load_network(netE, "E", opt.which_epoch, opt)
|
||||
|
||||
return netG, netD, netE
|
||||
|
||||
# preprocess the input, such as moving the tensors to GPUs and
|
||||
# transforming the label map to one-hot encoding
|
||||
# |data|: dictionary of the input data
|
||||
|
||||
def preprocess_input(self, data):
|
||||
# move to GPU and change data types
|
||||
# data['label'] = data['label'].long()
|
||||
|
||||
if not self.opt.isTrain:
|
||||
if self.use_gpu():
|
||||
data["label"] = data["label"].cuda()
|
||||
data["image"] = data["image"].cuda()
|
||||
return data["label"], data["image"], data["image"]
|
||||
|
||||
## While testing, the input image is the degraded face
|
||||
if self.use_gpu():
|
||||
data["label"] = data["label"].cuda()
|
||||
data["degraded_image"] = data["degraded_image"].cuda()
|
||||
data["image"] = data["image"].cuda()
|
||||
|
||||
# # create one-hot label map
|
||||
# label_map = data['label']
|
||||
# bs, _, h, w = label_map.size()
|
||||
# nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
|
||||
# else self.opt.label_nc
|
||||
# input_label = self.FloatTensor(bs, nc, h, w).zero_()
|
||||
# input_semantics = input_label.scatter_(1, label_map, 1.0)
|
||||
|
||||
return data["label"], data["image"], data["degraded_image"]
|
||||
|
||||
def compute_generator_loss(self, input_semantics, degraded_image, real_image):
|
||||
G_losses = {}
|
||||
|
||||
fake_image, KLD_loss = self.generate_fake(
|
||||
input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae
|
||||
)
|
||||
|
||||
if self.opt.use_vae:
|
||||
G_losses["KLD"] = KLD_loss
|
||||
|
||||
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
|
||||
|
||||
G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False)
|
||||
|
||||
if not self.opt.no_ganFeat_loss:
|
||||
num_D = len(pred_fake)
|
||||
GAN_Feat_loss = self.FloatTensor(1).fill_(0)
|
||||
for i in range(num_D): # for each discriminator
|
||||
# last output is the final prediction, so we exclude it
|
||||
num_intermediate_outputs = len(pred_fake[i]) - 1
|
||||
for j in range(num_intermediate_outputs): # for each layer output
|
||||
unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
|
||||
GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D
|
||||
G_losses["GAN_Feat"] = GAN_Feat_loss
|
||||
|
||||
if not self.opt.no_vgg_loss:
|
||||
G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg
|
||||
|
||||
return G_losses, fake_image
|
||||
|
||||
def compute_discriminator_loss(self, input_semantics, degraded_image, real_image):
|
||||
D_losses = {}
|
||||
with torch.no_grad():
|
||||
fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image)
|
||||
fake_image = fake_image.detach()
|
||||
fake_image.requires_grad_()
|
||||
|
||||
pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
|
||||
|
||||
D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True)
|
||||
D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True)
|
||||
|
||||
return D_losses
|
||||
|
||||
def encode_z(self, real_image):
|
||||
mu, logvar = self.netE(real_image)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return z, mu, logvar
|
||||
|
||||
def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
|
||||
z = None
|
||||
KLD_loss = None
|
||||
if self.opt.use_vae:
|
||||
z, mu, logvar = self.encode_z(real_image)
|
||||
if compute_kld_loss:
|
||||
KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
|
||||
|
||||
fake_image = self.netG(input_semantics, degraded_image, z=z)
|
||||
|
||||
assert (
|
||||
not compute_kld_loss
|
||||
) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
|
||||
|
||||
return fake_image, KLD_loss
|
||||
|
||||
# Given fake and real image, return the prediction of discriminator
|
||||
# for each fake and real image.
|
||||
|
||||
def discriminate(self, input_semantics, fake_image, real_image):
|
||||
|
||||
if self.opt.no_parsing_map:
|
||||
fake_concat = fake_image
|
||||
real_concat = real_image
|
||||
else:
|
||||
fake_concat = torch.cat([input_semantics, fake_image], dim=1)
|
||||
real_concat = torch.cat([input_semantics, real_image], dim=1)
|
||||
|
||||
# In Batch Normalization, the fake and real images are
|
||||
# recommended to be in the same batch to avoid disparate
|
||||
# statistics in fake and real images.
|
||||
# So both fake and real images are fed to D all at once.
|
||||
fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
|
||||
|
||||
discriminator_out = self.netD(fake_and_real)
|
||||
|
||||
pred_fake, pred_real = self.divide_pred(discriminator_out)
|
||||
|
||||
return pred_fake, pred_real
|
||||
|
||||
# Take the prediction of fake and real images from the combined batch
|
||||
def divide_pred(self, pred):
|
||||
# the prediction contains the intermediate outputs of multiscale GAN,
|
||||
# so it's usually a list
|
||||
if type(pred) == list:
|
||||
fake = []
|
||||
real = []
|
||||
for p in pred:
|
||||
fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
|
||||
real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
|
||||
else:
|
||||
fake = pred[: pred.size(0) // 2]
|
||||
real = pred[pred.size(0) // 2 :]
|
||||
|
||||
return fake, real
|
||||
|
||||
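    # build a binary edge map marking pixels whose label differs from a horizontal or vertical neighbor (typically applied to the instance map)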
def get_edges(self, t):
|
||||
edge = self.ByteTensor(t.size()).zero_()
|
||||
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
||||
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
||||
edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
||||
edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
||||
return edge.float()
|
||||
|
||||
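    # VAE reparameterization trick (Kingma & Welling, 2014): z = mu + std * eps keeps the sampling step differentiable w.r.t. the encoder outputs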
def reparameterize(self, mu, logvar):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
eps = torch.randn_like(std)
|
||||
return eps.mul(std) + mu
|
||||
|
||||
def use_gpu(self):
|
||||
return torch.cuda.is_available() and len(self.opt.gpu_ids) > 0
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
@@ -1,347 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
from ..util import util
|
||||
from .. import models
|
||||
|
||||
|
||||
class BaseOptions:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
def initialize(self, parser):
|
||||
# experiment specifics
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
default="label2coco",
|
||||
help="name of the experiment. It decides where to store samples and models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gpu_ids",
|
||||
type=str,
|
||||
default="0",
|
||||
help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_dir",
|
||||
type=str,
|
||||
default="./checkpoints",
|
||||
help="models are saved here",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="pix2pix", help="which model to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--norm_G",
|
||||
type=str,
|
||||
default="spectralinstance",
|
||||
help="instance normalization or batch normalization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--norm_D",
|
||||
type=str,
|
||||
default="spectralinstance",
|
||||
help="instance normalization or batch normalization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--norm_E",
|
||||
type=str,
|
||||
default="spectralinstance",
|
||||
help="instance normalization or batch normalization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--phase", type=str, default="train", help="train, val, test, etc"
|
||||
)
|
||||
|
||||
# input/output sizes
|
||||
parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
|
||||
parser.add_argument(
|
||||
"--preprocess_mode",
|
||||
type=str,
|
||||
default="scale_width_and_crop",
|
||||
help="scaling and cropping of images at load time.",
|
||||
choices=(
|
||||
"resize_and_crop",
|
||||
"crop",
|
||||
"scale_width",
|
||||
"scale_width_and_crop",
|
||||
"scale_shortside",
|
||||
"scale_shortside_and_crop",
|
||||
"fixed",
|
||||
"none",
|
||||
"resize",
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Scale images to this size. The final image will be cropped to --crop_size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop_size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Crop to the width of crop_size (after initially scaling the images to load_size.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aspect_ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label_nc",
|
||||
type=int,
|
||||
default=182,
|
||||
help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--contain_dontcare_label",
|
||||
action="store_true",
|
||||
help="if the label map contains dontcare label (dontcare=255)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_nc", type=int, default=3, help="# of output image channels"
|
||||
)
|
||||
|
||||
# for setting inputs
|
||||
parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
|
||||
parser.add_argument("--dataset_mode", type=str, default="coco")
|
||||
parser.add_argument(
|
||||
"--serial_batches",
|
||||
action="store_true",
|
||||
help="if true, takes images in order to make batches, otherwise takes them randomly",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_flip",
|
||||
action="store_true",
|
||||
help="if specified, do not flip the images for data argumentation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nThreads", default=0, type=int, help="# threads for loading data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_dataset_size",
|
||||
type=int,
|
||||
default=sys.maxsize,
|
||||
help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_from_opt_file",
|
||||
action="store_true",
|
||||
help="load the options from checkpoints and use that as default",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_filelist_write",
|
||||
action="store_true",
|
||||
help="saves the current filelist into a text file, so that it loads faster",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_filelist_read",
|
||||
action="store_true",
|
||||
help="reads from the file list cache",
|
||||
)
|
||||
|
||||
# for displays
|
||||
parser.add_argument(
|
||||
"--display_winsize", type=int, default=400, help="display window size"
|
||||
)
|
||||
|
||||
# for generator
|
||||
parser.add_argument(
|
||||
"--netG",
|
||||
type=str,
|
||||
default="spade",
|
||||
help="selects model to use for netG (pix2pixhd | spade)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ngf", type=int, default=64, help="# of gen filters in first conv layer"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_type",
|
||||
type=str,
|
||||
default="xavier",
|
||||
help="network initialization [normal|xavier|kaiming|orthogonal]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_variance",
|
||||
type=float,
|
||||
default=0.02,
|
||||
help="variance of the initialization distribution",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--z_dim", type=int, default=256, help="dimension of the latent z vector"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_parsing_map",
|
||||
action="store_true",
|
||||
help="During training, we do not use the parsing map",
|
||||
)
|
||||
|
||||
# for instance-wise features
|
||||
parser.add_argument(
|
||||
"--no_instance",
|
||||
action="store_true",
|
||||
help="if specified, do *not* add instance map as input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nef",
|
||||
type=int,
|
||||
default=16,
|
||||
help="# of encoder filters in the first conv layer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_vae",
|
||||
action="store_true",
|
||||
help="enable training with an image encoder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tensorboard_log",
|
||||
action="store_true",
|
||||
help="use tensorboard to record the resutls",
|
||||
)
|
||||
|
||||
# parser.add_argument('--img_dir',)
|
||||
parser.add_argument(
|
||||
"--old_face_folder",
|
||||
type=str,
|
||||
default="",
|
||||
help="The folder name of input old face",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_face_label_folder",
|
||||
type=str,
|
||||
default="",
|
||||
help="The folder name of input old face label",
|
||||
)
|
||||
|
||||
parser.add_argument("--injection_layer", type=str, default="all", help="")
|
||||
|
||||
self.initialized = True
|
||||
return parser
|
||||
|
||||
def gather_options(self):
|
||||
# initialize parser with basic options
|
||||
if not self.initialized:
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser = self.initialize(parser)
|
||||
|
||||
# get the basic options
|
||||
opt, unknown = parser.parse_known_args()
|
||||
|
||||
# modify model-related parser options
|
||||
model_name = opt.model
|
||||
model_option_setter = models.get_option_setter(model_name)
|
||||
parser = model_option_setter(parser, self.isTrain)
|
||||
|
||||
# modify dataset-related parser options
|
||||
# dataset_mode = opt.dataset_mode
|
||||
# dataset_option_setter = data.get_option_setter(dataset_mode)
|
||||
# parser = dataset_option_setter(parser, self.isTrain)
|
||||
|
||||
opt, unknown = parser.parse_known_args()
|
||||
|
||||
# if there is opt_file, load it.
|
||||
# The previous default options will be overwritten
|
||||
if opt.load_from_opt_file:
|
||||
parser = self.update_options_from_file(parser, opt)
|
||||
|
||||
opt, unknown = parser.parse_known_args()
|
||||
self.parser = parser
|
||||
return opt
|
||||
|
||||
def print_options(self, opt):
|
||||
message = ""
|
||||
message += "----------------- Options ---------------\n"
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
comment = ""
|
||||
default = self.parser.get_default(k)
|
||||
if v != default:
|
||||
comment = "\t[default: %s]" % str(default)
|
||||
message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
|
||||
message += "----------------- End -------------------"
|
||||
# print(message)
|
||||
|
||||
def option_file_path(self, opt, makedir=False):
|
||||
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
if makedir:
|
||||
util.mkdirs(expr_dir)
|
||||
file_name = os.path.join(expr_dir, "opt")
|
||||
return file_name
|
||||
|
||||
def save_options(self, opt):
|
||||
file_name = self.option_file_path(opt, makedir=True)
|
||||
with open(file_name + ".txt", "wt") as opt_file:
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
comment = ""
|
||||
default = self.parser.get_default(k)
|
||||
if v != default:
|
||||
comment = "\t[default: %s]" % str(default)
|
||||
opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment))
|
||||
|
||||
with open(file_name + ".pkl", "wb") as opt_file:
|
||||
pickle.dump(opt, opt_file)
|
||||
|
||||
def update_options_from_file(self, parser, opt):
|
||||
new_opt = self.load_options(opt)
|
||||
for k, v in sorted(vars(opt).items()):
|
||||
if hasattr(new_opt, k) and v != getattr(new_opt, k):
|
||||
new_val = getattr(new_opt, k)
|
||||
parser.set_defaults(**{k: new_val})
|
||||
return parser
|
||||
|
||||
def load_options(self, opt):
|
||||
file_name = self.option_file_path(opt, makedir=False)
|
||||
new_opt = pickle.load(open(file_name + ".pkl", "rb"))
|
||||
return new_opt
|
||||
|
||||
def parse(self):
|
||||
|
||||
opt = self.gather_options()
|
||||
opt.isTrain = self.isTrain # train or test
|
||||
opt.contain_dontcare_label = False
|
||||
|
||||
self.print_options(opt)
|
||||
if opt.isTrain:
|
||||
self.save_options(opt)
|
||||
|
||||
# Set semantic_nc based on the option.
|
||||
# This will be convenient in many places
|
||||
opt.semantic_nc = (
|
||||
opt.label_nc
|
||||
+ (1 if opt.contain_dontcare_label else 0)
|
||||
+ (0 if opt.no_instance else 1)
|
||||
)
|
||||
|
||||
str_ids = opt.gpu_ids.split(",")
|
||||
opt.gpu_ids = []
|
||||
for str_id in str_ids:
|
||||
int_id = int(str_id)
|
||||
if int_id >= 0:
|
||||
opt.gpu_ids.append(int_id)
|
||||
|
||||
# set gpu ids
|
||||
if torch.cuda.is_available() and len(opt.gpu_ids) > 0:
|
||||
if torch.cuda.device_count() > opt.gpu_ids[0]:
|
||||
try:
|
||||
torch.cuda.set_device(opt.gpu_ids[0])
|
||||
                except Exception:
|
||||
print("Failed to set GPU device. Using CPU...")
|
||||
|
||||
else:
|
||||
print("Invalid GPU ID. Using CPU...")
|
||||
|
||||
# assert (len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids))
|
||||
|
||||
self.opt = opt
|
||||
return self.opt
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .base_options import BaseOptions
|
||||
|
||||
|
||||
class TestOptions(BaseOptions):
|
||||
def initialize(self, parser):
|
||||
BaseOptions.initialize(self, parser)
|
||||
parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.")
|
||||
parser.add_argument(
|
||||
"--which_epoch",
|
||||
type=str,
|
||||
default="latest",
|
||||
help="which epoch to load? set to latest to use latest cached model",
|
||||
)
|
||||
parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run")
|
||||
|
||||
parser.set_defaults(
|
||||
preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256
|
||||
)
|
||||
parser.set_defaults(serial_batches=True)
|
||||
parser.set_defaults(no_flip=True)
|
||||
parser.set_defaults(phase="test")
|
||||
self.isTrain = False
|
||||
return parser
|
||||
@@ -1,34 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation
|
||||
|
||||
import torchvision.transforms as T
|
||||
import warnings
|
||||
|
||||
from .options.test_options import TestOptions
|
||||
from .models.pix2pix_model import Pix2PixModel
|
||||
from .data import create_dataloader
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
tensor2image = T.ToPILImage()
|
||||
|
||||
|
||||
def test_face(face_images: list, args: dict) -> list:
|
||||
opt = TestOptions().parse()
|
||||
for K, V in args.items():
|
||||
setattr(opt, K, V)
|
||||
|
||||
dataloader = create_dataloader(opt, face_images)
|
||||
images = []
|
||||
|
||||
model = Pix2PixModel(opt)
|
||||
model.eval()
|
||||
|
||||
for i, data_i in enumerate(dataloader):
|
||||
if i * opt.batchSize >= opt.how_many:
|
||||
break
|
||||
|
||||
generated = model(data_i, mode="inference")
|
||||
|
||||
for b in range(generated.shape[0]):
|
||||
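            # the generator outputs values in [-1, 1]; map them to [0, 1] before converting to PIL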
images.append(tensor2image((generated[b] + 1) / 2))
|
||||
|
||||
return images
|
||||
@@ -1,2 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
@@ -1,74 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Helper class that keeps track of training iterations
|
||||
class IterationCounter:
|
||||
def __init__(self, opt, dataset_size):
|
||||
self.opt = opt
|
||||
self.dataset_size = dataset_size
|
||||
|
||||
self.first_epoch = 1
|
||||
self.total_epochs = opt.niter + opt.niter_decay
|
||||
self.epoch_iter = 0 # iter number within each epoch
|
||||
self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt")
|
||||
if opt.isTrain and opt.continue_train:
|
||||
try:
|
||||
self.first_epoch, self.epoch_iter = np.loadtxt(
|
||||
self.iter_record_path, delimiter=",", dtype=int
|
||||
)
|
||||
print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter))
|
||||
            except Exception:
|
||||
print(
|
||||
"Could not load iteration record at %s. Starting from beginning." % self.iter_record_path
|
||||
)
|
||||
|
||||
self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
|
||||
|
||||
# return the iterator of epochs for the training
|
||||
def training_epochs(self):
|
||||
return range(self.first_epoch, self.total_epochs + 1)
|
||||
|
||||
def record_epoch_start(self, epoch):
|
||||
self.epoch_start_time = time.time()
|
||||
self.epoch_iter = 0
|
||||
self.last_iter_time = time.time()
|
||||
self.current_epoch = epoch
|
||||
|
||||
def record_one_iteration(self):
|
||||
current_time = time.time()
|
||||
|
||||
# the last remaining batch is dropped (see data/__init__.py),
|
||||
# so we can assume batch size is always opt.batchSize
|
||||
self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
|
||||
self.last_iter_time = current_time
|
||||
self.total_steps_so_far += self.opt.batchSize
|
||||
self.epoch_iter += self.opt.batchSize
|
||||
|
||||
def record_epoch_end(self):
|
||||
current_time = time.time()
|
||||
self.time_per_epoch = current_time - self.epoch_start_time
|
||||
print(
|
||||
"End of epoch %d / %d \t Time Taken: %d sec"
|
||||
% (self.current_epoch, self.total_epochs, self.time_per_epoch)
|
||||
)
|
||||
if self.current_epoch % self.opt.save_epoch_freq == 0:
|
||||
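            # record the *next* epoch with iteration 0, so a resumed run starts cleanly at the following epoch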
np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d")
|
||||
print("Saved current iteration count at %s." % self.iter_record_path)
|
||||
|
||||
def record_current_iter(self):
|
||||
np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d")
|
||||
print("Saved current iteration count at %s." % self.iter_record_path)
|
||||
|
||||
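    # total_steps_so_far advances by batchSize per iteration, so the remainder checks below fire roughly once every *_freq steps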
def needs_saving(self):
|
||||
return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
|
||||
|
||||
def needs_printing(self):
|
||||
return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
|
||||
|
||||
def needs_displaying(self):
|
||||
return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
|
||||
@@ -1,217 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import importlib
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
import argparse
|
||||
import dill as pickle
|
||||
|
||||
|
||||
def save_obj(obj, name):
|
||||
with open(name, "wb") as f:
|
||||
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
def load_obj(name):
|
||||
with open(name, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def copyconf(default_opt, **kwargs):
|
||||
conf = argparse.Namespace(**vars(default_opt))
|
||||
for key in kwargs:
|
||||
print(key, kwargs[key])
|
||||
setattr(conf, key, kwargs[key])
|
||||
return conf
|
||||
|
||||
|
||||
# Converts a Tensor into a Numpy array
|
||||
# |imtype|: the desired type of the converted numpy array
|
||||
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
|
||||
if isinstance(image_tensor, list):
|
||||
image_numpy = []
|
||||
for i in range(len(image_tensor)):
|
||||
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
|
||||
return image_numpy
|
||||
|
||||
if image_tensor.dim() == 4:
|
||||
# transform each image in the batch
|
||||
images_np = []
|
||||
for b in range(image_tensor.size(0)):
|
||||
one_image = image_tensor[b]
|
||||
one_image_np = tensor2im(one_image)
|
||||
images_np.append(one_image_np.reshape(1, *one_image_np.shape))
|
||||
images_np = np.concatenate(images_np, axis=0)
|
||||
|
||||
return images_np
|
||||
|
||||
if image_tensor.dim() == 2:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
image_numpy = image_tensor.detach().cpu().float().numpy()
|
||||
if normalize:
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
else:
|
||||
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
|
||||
image_numpy = np.clip(image_numpy, 0, 255)
|
||||
if image_numpy.shape[2] == 1:
|
||||
image_numpy = image_numpy[:, :, 0]
|
||||
return image_numpy.astype(imtype)
|
||||
|
||||
|
||||
# Converts a one-hot tensor into a colorful label map
|
||||
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
|
||||
if label_tensor.dim() == 4:
|
||||
# transform each image in the batch
|
||||
images_np = []
|
||||
for b in range(label_tensor.size(0)):
|
||||
one_image = label_tensor[b]
|
||||
one_image_np = tensor2label(one_image, n_label, imtype)
|
||||
images_np.append(one_image_np.reshape(1, *one_image_np.shape))
|
||||
images_np = np.concatenate(images_np, axis=0)
|
||||
# if tile:
|
||||
# images_tiled = tile_images(images_np)
|
||||
# return images_tiled
|
||||
# else:
|
||||
# images_np = images_np[0]
|
||||
# return images_np
|
||||
return images_np
|
||||
|
||||
if label_tensor.dim() == 1:
|
||||
return np.zeros((64, 64, 3), dtype=np.uint8)
|
||||
if n_label == 0:
|
||||
return tensor2im(label_tensor, imtype)
|
||||
label_tensor = label_tensor.cpu().float()
|
||||
if label_tensor.size()[0] > 1:
|
||||
label_tensor = label_tensor.max(0, keepdim=True)[1]
|
||||
label_tensor = Colorize(n_label)(label_tensor)
|
||||
label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
|
||||
result = label_numpy.astype(imtype)
|
||||
return result
|
||||
|
||||
|
||||
def save_image(image_numpy, image_path, create_dir=False):
|
||||
if create_dir:
|
||||
os.makedirs(os.path.dirname(image_path), exist_ok=True)
|
||||
if len(image_numpy.shape) == 2:
|
||||
image_numpy = np.expand_dims(image_numpy, axis=2)
|
||||
if image_numpy.shape[2] == 1:
|
||||
image_numpy = np.repeat(image_numpy, 3, 2)
|
||||
image_pil = Image.fromarray(image_numpy)
|
||||
|
||||
# save to png
|
||||
image_pil.save(image_path.replace(".jpg", ".png"))
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, list) and not isinstance(paths, str):
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
else:
|
||||
mkdir(paths)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def atoi(text):
|
||||
return int(text) if text.isdigit() else text
|
||||
|
||||
|
||||
def natural_keys(text):
|
||||
"""
|
||||
alist.sort(key=natural_keys) sorts in human order
|
||||
http://nedbatchelder.com/blog/200712/human_sorting.html
|
||||
(See Toothy's implementation in the comments)
|
||||
"""
|
||||
return [atoi(c) for c in re.split("(\d+)", text)]
|
||||
|
||||
|
||||
def natural_sort(items):
|
||||
items.sort(key=natural_keys)
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
def find_class_in_module(target_cls_name, module):
|
||||
target_cls_name = target_cls_name.replace("_", "").lower()
|
||||
|
||||
if module == 'models.networks.generator':
|
||||
from ..models.networks import generator as clslib
|
||||
elif module == 'models.networks.encoder':
|
||||
from ..models.networks import encoder as clslib
|
||||
else:
|
||||
raise NotImplementedError(f'Module: {module}')
|
||||
|
||||
cls = None
|
||||
for name, clsobj in clslib.__dict__.items():
|
||||
if name.lower() == target_cls_name:
|
||||
cls = clsobj
|
||||
|
||||
if cls is None:
|
||||
print(
|
||||
"In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
|
||||
% (module, target_cls_name)
|
||||
)
|
||||
        exit(1)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
def save_network(net, label, epoch, opt):
|
||||
save_filename = "%s_net_%s.pth" % (epoch, label)
|
||||
save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
|
||||
torch.save(net.cpu().state_dict(), save_path)
|
||||
if len(opt.gpu_ids) and torch.cuda.is_available():
|
||||
net.cuda()
|
||||
|
||||
|
||||
def load_network(net, label, epoch, opt):
|
||||
save_filename = "%s_net_%s.pth" % (epoch, label)
|
||||
save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
save_path = os.path.join(save_dir, save_filename)
|
||||
if os.path.exists(save_path):
|
||||
weights = torch.load(save_path)
|
||||
net.load_state_dict(weights)
|
||||
return net
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Code from
|
||||
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
|
||||
# Modified so it complies with the Cityscapes label map colors
|
||||
###############################################################################
|
||||
def uint82bin(n, count=8):
|
||||
"""returns the binary of integer n, count refers to amount of bits"""
|
||||
return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
|
||||
|
||||
|
||||
class Colorize(object):
|
||||
def __init__(self, n=35):
|
||||
self.cmap = labelcolormap(n)
|
||||
self.cmap = torch.from_numpy(self.cmap[:n])
|
||||
|
||||
def __call__(self, gray_image):
|
||||
size = gray_image.size()
|
||||
color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
|
||||
|
||||
for label in range(0, len(self.cmap)):
|
||||
mask = (label == gray_image[0]).cpu()
|
||||
color_image[0][mask] = self.cmap[label][0]
|
||||
color_image[1][mask] = self.cmap[label][1]
|
||||
color_image[2][mask] = self.cmap[label][2]
|
||||
|
||||
return color_image
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import ntpath
|
||||
import time
|
||||
from . import util
|
||||
import scipy.misc
|
||||
|
||||
try:
|
||||
from io import BytesIO # Python 3.x
|
||||
except ImportError:
|
||||
    from StringIO import StringIO as BytesIO  # Python 2.7
|
||||
import torchvision.utils as vutils
|
||||
from tensorboardX import SummaryWriter
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(self, opt):
|
||||
self.opt = opt
|
||||
self.tf_log = opt.isTrain and opt.tf_log
|
||||
|
||||
self.tensorboard_log = opt.tensorboard_log
|
||||
|
||||
self.win_size = opt.display_winsize
|
||||
self.name = opt.name
|
||||
if self.tensorboard_log:
|
||||
|
||||
if self.opt.isTrain:
|
||||
self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
self.writer = SummaryWriter(log_dir=self.log_dir)
|
||||
else:
|
||||
# print("hi :)")
|
||||
self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
|
||||
if opt.isTrain:
|
||||
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
|
||||
with open(self.log_name, "a") as log_file:
|
||||
now = time.strftime("%c")
|
||||
log_file.write("================ Training Loss (%s) ================\n" % now)
|
||||
|
||||
# |visuals|: dictionary of images to display or save
|
||||
def display_current_results(self, visuals, epoch, step):
|
||||
|
||||
all_tensor = []
|
||||
if self.tensorboard_log:
|
||||
|
||||
for key, tensor in visuals.items():
|
||||
all_tensor.append((tensor.data.cpu() + 1) / 2)
|
||||
|
||||
output = torch.cat(all_tensor, 0)
|
||||
img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
|
||||
|
||||
if self.opt.isTrain:
|
||||
self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
|
||||
else:
|
||||
vutils.save_image(
|
||||
output,
|
||||
os.path.join(self.log_dir, str(step) + ".png"),
|
||||
nrow=self.opt.batchSize,
|
||||
padding=0,
|
||||
normalize=False,
|
||||
)
|
||||
|
||||
# errors: dictionary of error labels and values
|
||||
def plot_current_errors(self, errors, step):
|
||||
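        # legacy TensorFlow-summary path; note that self.tf is never assigned in this file, so this branch stays inactive unless tf is wired in elsewhere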
if self.tf_log:
|
||||
for tag, value in errors.items():
|
||||
value = value.mean().float()
|
||||
summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
|
||||
self.writer.add_summary(summary, step)
|
||||
|
||||
if self.tensorboard_log:
|
||||
|
||||
self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step)
|
||||
self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step)
|
||||
self.writer.add_scalars(
|
||||
"Loss/GAN",
|
||||
{
|
||||
"G": errors["GAN"].mean().float(),
|
||||
"D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2,
|
||||
},
|
||||
step,
|
||||
)
|
||||
|
||||
# errors: same format as |errors| of plotCurrentErrors
|
||||
def print_current_errors(self, epoch, i, errors, t):
|
||||
message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
|
||||
for k, v in errors.items():
|
||||
v = v.mean().float()
|
||||
message += "%s: %.3f " % (k, v)
|
||||
|
||||
print(message)
|
||||
with open(self.log_name, "a") as log_file:
|
||||
log_file.write("%s\n" % message)
|
||||
|
||||
def convert_visuals_to_numpy(self, visuals):
|
||||
for key, t in visuals.items():
|
||||
tile = self.opt.batchSize > 8
|
||||
if "input_label" == key:
|
||||
t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy
|
||||
else:
|
||||
t = util.tensor2im(t, tile=tile)
|
||||
visuals[key] = t
|
||||
return visuals
|
||||
|
||||
# save image to the disk
|
||||
def save_images(self, webpage, visuals, image_path):
|
||||
visuals = self.convert_visuals_to_numpy(visuals)
|
||||
|
||||
image_dir = webpage.get_image_dir()
|
||||
short_path = ntpath.basename(image_path[0])
|
||||
name = os.path.splitext(short_path)[0]
|
||||
|
||||
webpage.add_header(name)
|
||||
ims = []
|
||||
txts = []
|
||||
links = []
|
||||
|
||||
for label, image_numpy in visuals.items():
|
||||
image_name = os.path.join(label, "%s.png" % (name))
|
||||
save_path = os.path.join(image_dir, image_name)
|
||||
util.save_image(image_numpy, save_path, create_dir=True)
|
||||
|
||||
ims.append(image_name)
|
||||
txts.append(label)
|
||||
links.append(image_name)
|
||||
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||
@@ -1,63 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import struct
|
||||
from PIL import Image
|
||||
|
||||
IMG_EXTENSIONS = [
|
||||
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
||||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
||||
]
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir):
|
||||
images = []
|
||||
assert os.path.isdir(dir), '%s is not a valid directory' % dir
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
#print(fname)
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
|
||||
return images
|
||||
|
||||
### Modify these 3 lines in your own environment
|
||||
indir="/home/ziyuwan/workspace/data/temp_old"
|
||||
target_folders=['VOC','Real_L_old','Real_RGB_old']
|
||||
out_dir ="/home/ziyuwan/workspace/data/temp_old"
|
||||
###
|
||||
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
|
||||
|
||||
for target_folder in target_folders:
|
||||
curr_indir = os.path.join(indir, target_folder)
|
||||
curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder)))
|
||||
image_lists = make_dataset(curr_indir)
|
||||
image_lists.sort()
|
||||
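    # .bigfile layout: int32 image count, then per image: int32 name length, name bytes, int32 data length, raw encoded image bytes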
with open(curr_out_file, 'wb') as wfid:
|
||||
# write total image number
|
||||
wfid.write(struct.pack('i', len(image_lists)))
|
||||
for i, img_path in enumerate(image_lists):
|
||||
# write file name first
|
||||
img_name = os.path.basename(img_path)
|
||||
img_name_bytes = img_name.encode('utf-8')
|
||||
wfid.write(struct.pack('i', len(img_name_bytes)))
|
||||
wfid.write(img_name_bytes)
|
||||
            # write image data
|
||||
with open(img_path, 'rb') as img_fid:
|
||||
img_bytes = img_fid.read()
|
||||
wfid.write(struct.pack('i', len(img_bytes)))
|
||||
wfid.write(img_bytes)
|
||||
|
||||
if i % 1000 == 0:
|
||||
                print('wrote %d images' % i)
|
||||
@@ -1,42 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import io
|
||||
import os
|
||||
import struct
|
||||
from PIL import Image
|
||||
|
||||
class BigFileMemoryLoader(object):
|
||||
def __load_bigfile(self):
|
||||
        print('start loading bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path) / 1024 / 1024 / 1024))
|
||||
with open(self.file_path, 'rb') as fid:
|
||||
self.img_num = struct.unpack('i', fid.read(4))[0]
|
||||
self.img_names = []
|
||||
self.img_bytes = []
|
||||
            print('found %d images in total' % self.img_num)
|
||||
for i in range(self.img_num):
|
||||
img_name_len = struct.unpack('i', fid.read(4))[0]
|
||||
img_name = fid.read(img_name_len).decode('utf-8')
|
||||
self.img_names.append(img_name)
|
||||
img_bytes_len = struct.unpack('i', fid.read(4))[0]
|
||||
self.img_bytes.append(fid.read(img_bytes_len))
|
||||
if i % 5000 == 0:
|
||||
                    print('loaded %d images' % i)
        print('loaded all %d images' % self.img_num)
|
||||
|
||||
def __init__(self, file_path):
|
||||
super(BigFileMemoryLoader, self).__init__()
|
||||
self.file_path = file_path
|
||||
self.__load_bigfile()
|
||||
|
||||
def __getitem__(self, index):
|
||||
try:
|
||||
img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
|
||||
return self.img_names[index], img
|
||||
except Exception:
|
||||
print('Image read error for index %d: %s' % (index, self.img_names[index]))
|
||||
return self.__getitem__((index+1)%self.img_num)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.img_num
|
||||
@@ -1,16 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
class BaseDataLoader():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
pass
|
||||
|
||||
def load_data():
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
class BaseDataset(data.Dataset):
|
||||
def __init__(self):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
def name(self):
|
||||
return 'BaseDataset'
|
||||
|
||||
def initialize(self, opt):
|
||||
pass
|
||||
|
||||
def get_params(opt, size):
|
||||
w, h = size
|
||||
new_h = h
|
||||
new_w = w
|
||||
if opt.resize_or_crop == 'resize_and_crop':
|
||||
new_h = new_w = opt.loadSize
|
||||
|
||||
if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256
|
||||
|
||||
if w<h:
|
||||
new_w = opt.loadSize
|
||||
new_h = opt.loadSize * h // w
|
||||
else:
|
||||
new_h=opt.loadSize
|
||||
new_w = opt.loadSize * w // h
|
||||
|
||||
    if opt.resize_or_crop == 'crop_only':
|
||||
pass
|
||||
|
||||
|
||||
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
|
||||
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
|
||||
|
||||
flip = random.random() > 0.5
|
||||
return {'crop_pos': (x, y), 'flip': flip}
|
||||
|
||||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
|
||||
transform_list = []
|
||||
if 'resize' in opt.resize_or_crop:
|
||||
osize = [opt.loadSize, opt.loadSize]
|
||||
transform_list.append(transforms.Scale(osize, method))
|
||||
elif 'scale_width' in opt.resize_or_crop:
|
||||
# transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it.
|
||||
        transform_list.append(transforms.Scale(256, method))
|
||||
|
||||
if 'crop' in opt.resize_or_crop:
|
||||
if opt.isTrain:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
|
||||
else:
|
||||
if opt.test_random_crop:
|
||||
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
||||
else:
|
||||
transform_list.append(transforms.CenterCrop(opt.fineSize))
|
||||
|
||||
## when testing, for ablation study, choose center_crop directly.
|
||||
|
||||
|
||||
|
||||
if opt.resize_or_crop == 'none':
|
||||
base = float(2 ** opt.n_downsample_global)
|
||||
if opt.netG == 'local':
|
||||
base *= (2 ** opt.n_local_enhancers)
|
||||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
|
||||
|
||||
if opt.isTrain and not opt.no_flip:
|
||||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
||||
|
||||
transform_list += [transforms.ToTensor()]
|
||||
|
||||
if normalize:
|
||||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
|
||||
(0.5, 0.5, 0.5))]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
def normalize():
|
||||
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
|
||||
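# round both spatial sizes to the nearest multiple of 'base' so repeated stride-2 (de)convolutions keep consistent shapes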
def __make_power_2(img, base, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
h = int(round(oh / base) * base)
|
||||
w = int(round(ow / base) * base)
|
||||
if (h == oh) and (w == ow):
|
||||
return img
|
||||
return img.resize((w, h), method)
|
||||
|
||||
def __scale_width(img, target_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
if (ow == target_width):
|
||||
return img
|
||||
w = target_width
|
||||
h = int(target_width * oh / ow)
|
||||
return img.resize((w, h), method)
|
||||
|
||||
def __crop(img, pos, size):
|
||||
ow, oh = img.size
|
||||
x1, y1 = pos
|
||||
tw = th = size
|
||||
if (ow > tw or oh > th):
|
||||
return img.crop((x1, y1, x1 + tw, y1 + th))
|
||||
return img
|
||||
|
||||
def __flip(img, flip):
|
||||
if flip:
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
return img
|
||||
@@ -1,41 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.utils.data
|
||||
import random
|
||||
from data.base_data_loader import BaseDataLoader
|
||||
from data import online_dataset_for_old_photos as dts_ray_bigfile
|
||||
|
||||
|
||||
def CreateDataset(opt):
|
||||
dataset = None
|
||||
if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B':
|
||||
dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
|
||||
if opt.training_dataset=='mapping':
|
||||
if opt.random_hole:
|
||||
dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
|
||||
else:
|
||||
dataset = dts_ray_bigfile.PairOldPhotos()
|
||||
print("dataset [%s] was created" % (dataset.name()))
|
||||
dataset.initialize(opt)
|
||||
return dataset
|
||||
|
||||
class CustomDatasetDataLoader(BaseDataLoader):
|
||||
def name(self):
|
||||
return 'CustomDatasetDataLoader'
|
||||
|
||||
def initialize(self, opt):
|
||||
BaseDataLoader.initialize(self, opt)
|
||||
self.dataset = CreateDataset(opt)
|
||||
self.dataloader = torch.utils.data.DataLoader(
|
||||
self.dataset,
|
||||
batch_size=opt.batchSize,
|
||||
shuffle=not opt.serial_batches,
|
||||
num_workers=int(opt.nThreads),
|
||||
drop_last=True)
|
||||
|
||||
def load_data(self):
|
||||
return self.dataloader
|
||||
|
||||
def __len__(self):
|
||||
return min(len(self.dataset), self.opt.max_dataset_size)
|
||||
@@ -1,9 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
def CreateDataLoader(opt):
|
||||
from data.custom_dataset_data_loader import CustomDatasetDataLoader
|
||||
data_loader = CustomDatasetDataLoader()
|
||||
print(data_loader.name())
|
||||
data_loader.initialize(opt)
|
||||
return data_loader
|
||||
@@ -1,62 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
IMG_EXTENSIONS = [
|
||||
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
||||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
|
||||
]
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir):
|
||||
images = []
|
||||
assert os.path.isdir(dir), '%s is not a valid directory' % dir
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def default_loader(path):
|
||||
return Image.open(path).convert('RGB')
|
||||
|
||||
|
||||
class ImageFolder(data.Dataset):
|
||||
|
||||
def __init__(self, root, transform=None, return_paths=False,
|
||||
loader=default_loader):
|
||||
imgs = make_dataset(root)
|
||||
if len(imgs) == 0:
|
||||
raise(RuntimeError("Found 0 images in: " + root + "\n"
|
||||
"Supported image extensions are: " +
|
||||
",".join(IMG_EXTENSIONS)))
|
||||
|
||||
self.root = root
|
||||
self.imgs = imgs
|
||||
self.transform = transform
|
||||
self.return_paths = return_paths
|
||||
self.loader = loader
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.imgs[index]
|
||||
img = self.loader(path)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if self.return_paths:
|
||||
return img, path
|
||||
else:
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
@@ -1,144 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image, ImageFile
|
||||
import torch.nn.functional as F
|
||||
import torchvision as tv
|
||||
import numpy as np
|
||||
import warnings
|
||||
import argparse
|
||||
import torch
|
||||
import gc
|
||||
import os
|
||||
|
||||
from .detection_models import networks
|
||||
|
||||
|
||||
tensor2image = transforms.ToPILImage()
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
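# resize helper: "full_size" keeps the resolution but snaps both sides to multiples of 16 (required by the network's downsampling); "scale_256" first scales the shorter side to 256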
def data_transforms(img, full_size, method=Image.BICUBIC):
|
||||
if full_size == "full_size":
|
||||
ow, oh = img.size
|
||||
h = int(round(oh / 16) * 16)
|
||||
w = int(round(ow / 16) * 16)
|
||||
if (h == oh) and (w == ow):
|
||||
return img
|
||||
return img.resize((w, h), method)
|
||||
|
||||
elif full_size == "scale_256":
|
||||
ow, oh = img.size
|
||||
pw, ph = ow, oh
|
||||
if ow < oh:
|
||||
ow = 256
|
||||
oh = ph / pw * 256
|
||||
else:
|
||||
oh = 256
|
||||
ow = pw / ph * 256
|
||||
|
||||
h = int(round(oh / 16) * 16)
|
||||
w = int(round(ow / 16) * 16)
|
||||
if (h == ph) and (w == pw):
|
||||
return img
|
||||
return img.resize((w, h), method)
|
||||
|
||||
|
||||
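# note: img_tensor is (N, C, H, W), so the local names w/h are transposed; the logic is symmetric and scales the shorter spatial side to default_scale, snapped to a multiple of 16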
def scale_tensor(img_tensor, default_scale=256):
|
||||
_, _, w, h = img_tensor.shape
|
||||
if w < h:
|
||||
ow = default_scale
|
||||
oh = h / w * default_scale
|
||||
else:
|
||||
oh = default_scale
|
||||
ow = w / h * default_scale
|
||||
|
||||
oh = int(round(oh / 16) * 16)
|
||||
ow = int(round(ow / 16) * 16)
|
||||
|
||||
return F.interpolate(img_tensor, [ow, oh], mode="bilinear")
|
||||
|
||||
|
||||
def blend_mask(img, mask):
|
||||
|
||||
np_img = np.array(img).astype("float")
|
||||
|
||||
return Image.fromarray(
|
||||
(np_img * (1 - mask) + mask * 255.0).astype("uint8")
|
||||
).convert("RGB")
|
||||
|
||||
|
||||
def main(config: argparse.Namespace, input_image: Image):
|
||||
|
||||
model = networks.UNet(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
depth=4,
|
||||
conv_num=2,
|
||||
wf=6,
|
||||
padding=True,
|
||||
batch_norm=True,
|
||||
up_mode="upsample",
|
||||
with_tanh=False,
|
||||
sync_bn=True,
|
||||
antialiasing=True,
|
||||
)
|
||||
|
||||
## load model
|
||||
checkpoint_path = os.path.join(
|
||||
os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt"
|
||||
)
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model_state"])
|
||||
print("model weights loaded")
|
||||
|
||||
if torch.cuda.is_available() and config.GPU >= 0:
|
||||
model.to(config.GPU)
|
||||
else:
|
||||
model.cpu()
|
||||
|
||||
model.eval()
|
||||
|
||||
print("processing...")
|
||||
|
||||
transformed_image_PIL = data_transforms(input_image, config.input_size)
|
||||
input_image = transformed_image_PIL.convert("L")
|
||||
input_image = tv.transforms.ToTensor()(input_image)
|
||||
input_image = tv.transforms.Normalize([0.5], [0.5])(input_image)
|
||||
input_image = torch.unsqueeze(input_image, 0)
|
||||
|
||||
_, _, ow, oh = input_image.shape
|
||||
scratch_image_scale = scale_tensor(input_image)
|
||||
|
||||
if torch.cuda.is_available() and config.GPU >= 0:
|
||||
scratch_image_scale = scratch_image_scale.to(config.GPU)
|
||||
else:
|
||||
scratch_image_scale = scratch_image_scale.cpu()
|
||||
|
||||
with torch.no_grad():
|
||||
P = torch.sigmoid(model(scratch_image_scale))
|
||||
|
||||
P = P.data.cpu()
|
||||
P = F.interpolate(P, [ow, oh], mode="nearest")
|
||||
|
||||
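    # binarize the per-pixel scratch probability at a 0.4 threshold and scale to an 8-bit 0/255 mask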
scratch_mask = torch.clamp((P >= 0.4).float(), 0.0, 1.0) * 255
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return tensor2image(scratch_mask[0].byte()), transformed_image_PIL
|
||||
|
||||
|
||||
def global_detection(
|
||||
input_image: Image,
|
||||
gpu: int,
|
||||
input_size: str,
|
||||
) -> Image:
|
||||
|
||||
config = argparse.Namespace()
|
||||
config.GPU = gpu
|
||||
config.input_size = input_size
|
||||
|
||||
return main(config, input_image)
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn.parallel
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
# https://github.com/adobe/antialiased-cnns
|
||||
|
||||
def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
|
||||
super(Downsample, self).__init__()
|
||||
self.filt_size = filt_size
|
||||
self.pad_off = pad_off
|
||||
self.pad_sizes = [
|
||||
int(1.0 * (filt_size - 1) / 2),
|
||||
int(np.ceil(1.0 * (filt_size - 1) / 2)),
|
||||
int(1.0 * (filt_size - 1) / 2),
|
||||
int(np.ceil(1.0 * (filt_size - 1) / 2)),
|
||||
]
|
||||
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
||||
self.stride = stride
|
||||
self.off = int((self.stride - 1) / 2.0)
|
||||
self.channels = channels
|
||||
|
||||
# print('Filter size [%i]'%filt_size)
|
||||
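        # 1-D binomial filter taps (rows of Pascal's triangle); their outer product below forms the 2-D anti-aliasing blur kernel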
if self.filt_size == 1:
|
||||
            a = np.array([1.0])
|
||||
elif self.filt_size == 2:
|
||||
a = np.array([1.0, 1.0])
|
||||
elif self.filt_size == 3:
|
||||
a = np.array([1.0, 2.0, 1.0])
|
||||
elif self.filt_size == 4:
|
||||
a = np.array([1.0, 3.0, 3.0, 1.0])
|
||||
elif self.filt_size == 5:
|
||||
a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||
elif self.filt_size == 6:
|
||||
a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
|
||||
        elif self.filt_size == 7:
            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
        else:
            raise ValueError("Downsample: unsupported filt_size %d (expected 1-7)" % self.filt_size)
|
||||
|
||||
filt = torch.Tensor(a[:, None] * a[None, :])
|
||||
filt = filt / torch.sum(filt)
|
||||
self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
||||
|
||||
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
|
||||
|
||||
def forward(self, inp):
|
||||
if self.filt_size == 1:
|
||||
if self.pad_off == 0:
|
||||
return inp[:, :, :: self.stride, :: self.stride]
|
||||
else:
|
||||
return self.pad(inp)[:, :, :: self.stride, :: self.stride]
|
||||
else:
|
||||
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
||||
|
||||
|
||||
def get_pad_layer(pad_type):
|
||||
if pad_type in ["refl", "reflect"]:
|
||||
PadLayer = nn.ReflectionPad2d
|
||||
elif pad_type in ["repl", "replicate"]:
|
||||
PadLayer = nn.ReplicationPad2d
|
||||
elif pad_type == "zero":
|
||||
PadLayer = nn.ZeroPad2d
|
||||
else:
|
||||
print("Pad type [%s] not recognized" % pad_type)
|
||||
return PadLayer
|
||||
@@ -1,332 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .sync_batchnorm import DataParallelWithCallback
|
||||
from .antialiasing import Downsample
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
depth=5,
|
||||
conv_num=2,
|
||||
wf=6,
|
||||
padding=True,
|
||||
batch_norm=True,
|
||||
up_mode="upsample",
|
||||
with_tanh=False,
|
||||
sync_bn=True,
|
||||
antialiasing=True,
|
||||
):
|
||||
"""
|
||||
Implementation of
|
||||
U-Net: Convolutional Networks for Biomedical Image Segmentation
|
||||
(Ronneberger et al., 2015)
|
||||
https://arxiv.org/abs/1505.04597
|
||||
Using the default arguments will yield the exact version used
|
||||
in the original paper
|
||||
Args:
|
||||
in_channels (int): number of input channels
|
||||
out_channels (int): number of output channels
|
||||
depth (int): depth of the network
|
||||
wf (int): number of filters in the first layer is 2**wf
|
||||
padding (bool): if True, apply padding such that the input shape
|
||||
is the same as the output.
|
||||
This may introduce artifacts
|
||||
batch_norm (bool): Use BatchNorm after layers with an
|
||||
activation function
|
||||
up_mode (str): one of 'upconv' or 'upsample'.
|
||||
'upconv' will use transposed convolutions for
|
||||
learned upsampling.
|
||||
'upsample' will use bilinear upsampling.
|
||||
"""
|
||||
super().__init__()
|
||||
assert up_mode in ("upconv", "upsample")
|
||||
self.padding = padding
|
||||
self.depth = depth - 1
|
||||
prev_channels = in_channels
|
||||
|
||||
self.first = nn.Sequential(
|
||||
*[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
|
||||
)
|
||||
prev_channels = 2 ** wf
|
||||
|
||||
self.down_path = nn.ModuleList()
|
||||
self.down_sample = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
if antialiasing and depth > 0:
|
||||
self.down_sample.append(
|
||||
nn.Sequential(
|
||||
*[
|
||||
nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
|
||||
nn.BatchNorm2d(prev_channels),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
Downsample(channels=prev_channels, stride=2),
|
||||
]
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_sample.append(
|
||||
nn.Sequential(
|
||||
*[
|
||||
nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
|
||||
nn.BatchNorm2d(prev_channels),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.down_path.append(
|
||||
UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
|
||||
)
|
||||
prev_channels = 2 ** (wf + i + 1)
|
||||
|
||||
self.up_path = nn.ModuleList()
|
||||
for i in reversed(range(depth)):
|
||||
self.up_path.append(
|
||||
UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
|
||||
)
|
||||
prev_channels = 2 ** (wf + i)
|
||||
|
||||
if with_tanh:
|
||||
self.last = nn.Sequential(
|
||||
*[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
|
||||
)
|
||||
else:
|
||||
self.last = nn.Sequential(
|
||||
*[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
|
||||
)
|
||||
|
||||
if sync_bn:
|
||||
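            # note: rebinding self here has no effect on the constructed module; wrap the model in DataParallelWithCallback at the call site instead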
self = DataParallelWithCallback(self)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first(x)
|
||||
|
||||
blocks = []
|
||||
for i, down_block in enumerate(self.down_path):
|
||||
blocks.append(x)
|
||||
x = self.down_sample[i](x)
|
||||
x = down_block(x)
|
||||
|
||||
for i, up in enumerate(self.up_path):
|
||||
x = up(x, blocks[-i - 1])
|
||||
|
||||
return self.last(x)
|
||||
|
||||
|
||||
class UNetConvBlock(nn.Module):
|
||||
def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
|
||||
super(UNetConvBlock, self).__init__()
|
||||
block = []
|
||||
|
||||
for _ in range(conv_num):
|
||||
block.append(nn.ReflectionPad2d(padding=int(padding)))
|
||||
block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0))
|
||||
if batch_norm:
|
||||
block.append(nn.BatchNorm2d(out_size))
|
||||
block.append(nn.LeakyReLU(0.2, True))
|
||||
in_size = out_size
|
||||
|
||||
self.block = nn.Sequential(*block)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.block(x)
|
||||
return out
|
||||
|
||||
|
||||
class UNetUpBlock(nn.Module):
|
||||
def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
|
||||
super(UNetUpBlock, self).__init__()
|
||||
if up_mode == "upconv":
|
||||
self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
|
||||
elif up_mode == "upsample":
|
||||
self.up = nn.Sequential(
|
||||
nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
|
||||
nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
|
||||
)
|
||||
|
||||
self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)
|
||||
|
||||
def center_crop(self, layer, target_size):
|
||||
_, _, layer_height, layer_width = layer.size()
|
||||
diff_y = (layer_height - target_size[0]) // 2
|
||||
diff_x = (layer_width - target_size[1]) // 2
|
||||
return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]
|
||||
|
||||
def forward(self, x, bridge):
|
||||
up = self.up(x)
|
||||
crop1 = self.center_crop(bridge, up.shape[2:])
|
||||
out = torch.cat([up, crop1], 1)
|
||||
out = self.conv_block(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class UnetGenerator(nn.Module):
|
||||
"""Create a Unet-based generator"""
|
||||
|
||||
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
|
||||
"""Construct a Unet generator
|
||||
Parameters:
|
||||
input_nc (int) -- the number of channels in input images
|
||||
output_nc (int) -- the number of channels in output images
|
||||
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
|
||||
ngf (int) -- the number of filters in the last conv layer
|
||||
norm_layer -- normalization layer
|
||||
We construct the U-Net from the innermost layer to the outermost layer.
|
||||
It is a recursive process.
|
||||
"""
|
||||
super().__init__()
|
||||
if norm_type == "BN":
|
||||
norm_layer = nn.BatchNorm2d
|
||||
elif norm_type == "IN":
|
||||
norm_layer = nn.InstanceNorm2d
|
||||
else:
|
||||
raise NameError("Unknown norm layer")
|
||||
|
||||
# construct unet structure
|
||||
unet_block = UnetSkipConnectionBlock(
|
||||
ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
|
||||
) # add the innermost layer
|
||||
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
||||
unet_block = UnetSkipConnectionBlock(
|
||||
ngf * 8,
|
||||
ngf * 8,
|
||||
input_nc=None,
|
||||
submodule=unet_block,
|
||||
norm_layer=norm_layer,
|
||||
use_dropout=use_dropout,
|
||||
)
|
||||
# gradually reduce the number of filters from ngf * 8 to ngf
|
||||
unet_block = UnetSkipConnectionBlock(
|
||||
ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
|
||||
)
|
||||
unet_block = UnetSkipConnectionBlock(
|
||||
ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
|
||||
)
|
||||
unet_block = UnetSkipConnectionBlock(
|
||||
ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
|
||||
)
|
||||
self.model = UnetSkipConnectionBlock(
|
||||
output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer
|
||||
) # add the outermost layer
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
class UnetSkipConnectionBlock(nn.Module):
|
||||
"""Defines the Unet submodule with skip connection.
|
||||
|
||||
-------------------identity----------------------
|
||||
|-- downsampling -- |submodule| -- upsampling --|
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_nc,
|
||||
inner_nc,
|
||||
input_nc=None,
|
||||
submodule=None,
|
||||
outermost=False,
|
||||
innermost=False,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
use_dropout=False,
|
||||
):
|
||||
"""Construct a Unet submodule with skip connections.
|
||||
Parameters:
|
||||
outer_nc (int) -- the number of filters in the outer conv layer
|
||||
inner_nc (int) -- the number of filters in the inner conv layer
|
||||
input_nc (int) -- the number of channels in input images/features
|
||||
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
||||
outermost (bool) -- if this module is the outermost module
|
||||
innermost (bool) -- if this module is the innermost module
|
||||
norm_layer -- normalization layer
|
||||
user_dropout (bool) -- if use dropout layers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.outermost = outermost
|
||||
use_bias = norm_layer == nn.InstanceNorm2d
|
||||
if input_nc is None:
|
||||
input_nc = outer_nc
|
||||
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
|
||||
downrelu = nn.LeakyReLU(0.2, True)
|
||||
downnorm = norm_layer(inner_nc)
|
||||
uprelu = nn.LeakyReLU(0.2, True)
|
||||
upnorm = norm_layer(outer_nc)
|
||||
|
||||
if outermost:
|
||||
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
|
||||
down = [downconv]
|
||||
up = [uprelu, upconv, nn.Tanh()]
|
||||
model = down + [submodule] + up
|
||||
elif innermost:
|
||||
upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
|
||||
down = [downrelu, downconv]
|
||||
up = [uprelu, upconv, upnorm]
|
||||
model = down + up
|
||||
else:
|
||||
upconv = nn.ConvTranspose2d(
|
||||
inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
|
||||
)
|
||||
down = [downrelu, downconv, downnorm]
|
||||
up = [uprelu, upconv, upnorm]
|
||||
|
||||
if use_dropout:
|
||||
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
||||
else:
|
||||
model = down + [submodule] + up
|
||||
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, x):
|
||||
if self.outermost:
|
||||
return self.model(x)
|
||||
else: # add skip connections
|
||||
return torch.cat([x, self.model(x)], 1)
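
# --- illustrative sketch (not part of the original file): the skip concat in
# UnetSkipConnectionBlock.forward is why all but the innermost blocks build
# their upconv with inner_nc * 2 input channels ---
_x = torch.zeros(1, 64, 32, 32)                  # block input
_y = torch.zeros(1, 64, 32, 32)                  # stand-in for self.model(_x)
assert torch.cat([_x, _y], 1).shape[1] == 128    # channels double after the cat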


# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    from torchsummary import summary
    from torchviz import make_dot  # assumed source of make_dot; the original import is not shown

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet_two_decoders(
        in_channels=3,
        out_channels1=3,
        out_channels2=1,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    print("customized unet:")
    summary(model, (3, 256, 256))

    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256))

    x = torch.zeros(1, 3, 256, 256).requires_grad_(True).to(device)  # was .cuda(), which fails on CPU-only hosts
    g = make_dot(model(x))
    g.render("models/Digraph.gv", view=False)

@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import set_sbn_eps_mode
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
@@ -1,412 +0,0 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'set_sbn_eps_mode',
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


SBN_EPS_MODE = 'clamp'


def set_sbn_eps_mode(mode):
    global SBN_EPS_MODE
    assert mode in ('clamp', 'plus')
    SBN_EPS_MODE = mode


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)
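
# --- illustrative sketch (not part of the original file): the two helpers
# reduce to, and broadcast from, per-channel vectors ---
_t = torch.ones(4, 3, 7)                          # (batch, channel, spatial)
assert _sum_ft(_t).shape == (3,)                  # summed over dims 0 and -1
assert _unsqueeze_ft(_sum_ft(_t)).shape == (1, 3, 1)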


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
                                                     track_running_stats=track_running_stats)

        if not self.track_running_stats:
            import warnings
            warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        if SBN_EPS_MODE == 'clamp':
            return mean, bias_var.clamp(self.eps) ** -0.5
        elif SBN_EPS_MODE == 'plus':
            return mean, (bias_var + self.eps) ** -0.5
        else:
            raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
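
# --- illustrative sketch (not part of the original file): the two
# SBN_EPS_MODE branches differ only in how eps enters the inverse sqrt ---
_var = torch.tensor(0.0)
_inv_clamp = _var.clamp(1e-5) ** -0.5             # 'clamp': floor the variance
_inv_plus = (_var + 1e-5) ** -0.5                 # 'plus' : add eps, then rsqrt
assert torch.isclose(_inv_clamp, _inv_plus)       # they agree for zero variance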


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))


@contextlib.contextmanager
def patch_sync_batchnorm():
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
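
# --- illustrative sketch (not part of the original file): inside the context
# manager, plain nn.BatchNorm*d constructors produce synchronized variants
# (assuming torchvision and multiple GPUs are available) ---
#
#   with patch_sync_batchnorm():
#       model = torchvision.models.resnet18()    # BatchNorm2d -> SynchronizedBatchNorm2d
#   model = DataParallelWithCallback(model, device_ids=[0, 1])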


def convert_model(module):
    """Traverse the input module and its children recursively
    and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
    with SynchronizedBatchNorm*N*d

    Args:
        module: the input module that needs to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod

@@ -1,74 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
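
if __name__ == "__main__":
    # Illustrative self-check (not part of the original file): in training
    # mode the reimplementation should match nn.BatchNorm2d given identical
    # affine parameters.
    bn_ref = nn.BatchNorm2d(8)
    bn_re = BatchNorm2dReimpl(8)
    bn_re.weight.data.copy_(bn_ref.weight.data)
    bn_re.bias.data.copy_(bn_ref.bias.data)
    x = torch.randn(4, 8, 16, 16)
    assert torch.allclose(bn_ref(x), bn_re(x), atol=1e-5)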

@@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result has not been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as data parallel triggers a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to
      be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
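
if __name__ == "__main__":
    # Illustrative sketch (not part of the original file) of one master/slave
    # round-trip: the callback sums every message and broadcasts the total.
    def _callback(intermediates):
        total = sum(msg for _, msg in intermediates)
        return [(identifier, total) for identifier, _ in intermediates]

    master = SyncMaster(_callback)
    pipe = master.register_slave(1)

    results = {}

    def _slave():
        results["slave"] = pipe.run_slave(2)   # the slave contributes 2

    worker = threading.Thread(target=_slave)
    worker.start()
    results["master"] = master.run_master(1)   # the master contributes 1
    worker.join()
    assert results["master"] == results["slave"] == 3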

@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module with a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
@@ -1,245 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import sys
import time
import shutil
import platform
import numpy as np
from datetime import datetime

import torch
import torchvision as tv
import torch.backends.cudnn as cudnn

# from torch.utils.tensorboard import SummaryWriter  # required by prepare_tensorboard below

import yaml
import matplotlib.pyplot as plt
from easydict import EasyDict as edict
import torchvision.utils as vutils


##### option parsing ######
def print_options(config_dict):
    print("------------ Options -------------")
    for k, v in sorted(config_dict.items()):
        print("%s: %s" % (str(k), str(v)))
    print("-------------- End ----------------")


def save_options(config_dict):
    from time import gmtime, strftime

    file_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"])
    mkdir_if_not(file_dir)
    file_name = os.path.join(file_dir, "opt.txt")
    with open(file_name, "wt") as opt_file:
        opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
        opt_file.write("------------ Options -------------\n")
        for k, v in sorted(config_dict.items()):
            opt_file.write("%s: %s\n" % (str(k), str(v)))
        opt_file.write("-------------- End ----------------\n")


def config_parse(config_file, options, save=True):
    with open(config_file, "r") as stream:
        config_dict = yaml.safe_load(stream)
        config = edict(config_dict)

    for option_key, option_value in vars(options).items():
        config_dict[option_key] = option_value
        config[option_key] = option_value

    if config.debug_mode:
        config_dict["num_workers"] = 0
        config.num_workers = 0
        config.batch_size = 2
        if isinstance(config.gpu_ids, str):
            config.gpu_ids = [int(x) for x in config.gpu_ids.split(",")][0]

    print_options(config_dict)
    if save:
        save_options(config_dict)

    return config
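
# --- illustrative sketch (not part of the original file): config_parse merges
# a YAML file with an argparse namespace; the keys below are hypothetical ---
#
#   # config.yaml contains e.g.:
#   #   checkpoint_dir: ./checkpoints
#   #   name: demo
#   #   gpu_ids: "0"
#   options = argparse.Namespace(debug_mode=False)
#   config = config_parse("config.yaml", options, save=False)
#   print(config.name)    # attribute access via EasyDict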


###### utility ######
def to_np(x):
    return x.cpu().numpy()


def prepare_device(use_gpu, gpu_ids):
    if torch.cuda.is_available() and use_gpu:
        cudnn.benchmark = True
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        if isinstance(gpu_ids, str):
            gpu_ids = [int(x) for x in gpu_ids.split(",")]
            torch.cuda.set_device(gpu_ids[0])
            device = torch.device("cuda:" + str(gpu_ids[0]))
        else:
            torch.cuda.set_device(gpu_ids)
            device = torch.device("cuda:" + str(gpu_ids))
        print("running on GPU {}".format(gpu_ids))
    else:
        device = torch.device("cpu")
        print("running on CPU")

    return device


###### file system ######
def get_dir_size(start_path="."):
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(start_path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size


def mkdir_if_not(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


##### System related ######
class Timer:
    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()

    def __exit__(self, exc_type, exc_value, exc_tb):
        elapse = time.time() - self.start_time
        print(self.msg % elapse)
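
# --- illustrative sketch (not part of the original file): the message is a
# %-format template that receives the elapsed seconds on __exit__ ---
#
#   with Timer("forward pass took %.3f s"):
#       run_inference()    # hypothetical workload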


###### interactive ######
def get_size(start_path="."):
    # note: duplicates get_dir_size above
    total_size = 0
    for dirpath, dirnames, filenames in os.walk(start_path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            total_size += os.path.getsize(fp)
    return total_size


def clean_tensorboard(directory):
    tensorboard_list = os.listdir(directory)
    SIZE_THRESH = 100000
    for tensorboard in tensorboard_list:
        tensorboard = os.path.join(directory, tensorboard)
        if get_size(tensorboard) < SIZE_THRESH:
            print("deleting the empty tensorboard: ", tensorboard)
            if os.path.isdir(tensorboard):
                shutil.rmtree(tensorboard)
            else:
                os.remove(tensorboard)


def prepare_tensorboard(config, experiment_name=datetime.now().strftime("%Y-%m-%d %H-%M-%S")):
    # requires the SummaryWriter import at the top of this file to be enabled
    tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs")
    mkdir_if_not(tensorboard_directory)
    clean_tensorboard(tensorboard_directory)
    tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10)

    # try:
    #     shutil.copy('outputs/opt.txt', tensorboard_directory)
    # except:
    #     print('cannot find file opt.txt')
    return tb_writer


def tb_loss_logger(tb_writer, iter_index, loss_logger):
    for tag, value in loss_logger.items():
        tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index)


def tb_image_logger(tb_writer, iter_index, images_info, config):
    ### Save and write the output into the tensorboard
    tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode)
    mkdir_if_not(tb_logger_path)
    for tag, image in images_info.items():
        if tag == "test_image_prediction" or tag == "image_prediction":
            continue
        image = tv.utils.make_grid(image.cpu())
        image = torch.clamp(image, 0, 1)
        tb_writer.add_image(tag, img_tensor=image, global_step=iter_index)
        tv.transforms.functional.to_pil_image(image).save(
            os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag))
        )


def tb_image_logger_test(epoch, iter, images_info, config):

    url = os.path.join(config.output_dir, config.name, config.train_mode, "val_" + str(epoch))
    if not os.path.exists(url):
        os.makedirs(url)
    scratch_img = images_info["test_scratch_image"].data.cpu()
    if config.norm_input:
        scratch_img = (scratch_img + 1.0) / 2.0
    scratch_img = torch.clamp(scratch_img, 0, 1)
    gt_mask = images_info["test_mask_image"].data.cpu()
    predict_mask = images_info["test_scratch_prediction"].data.cpu()

    predict_hard_mask = (predict_mask.data.cpu() >= 0.5).float()

    imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0)
    img_grid = vutils.save_image(
        imgs, os.path.join(url, str(iter) + ".jpg"), nrow=len(scratch_img), padding=0, normalize=True
    )


def imshow(input_image, title=None, to_numpy=False):
    inp = input_image
    if to_numpy or type(input_image) is torch.Tensor:
        inp = input_image.numpy()

    fig = plt.figure()
    if inp.ndim == 2:
        fig = plt.imshow(inp, cmap="gray", clim=[0, 255])
    else:
        fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8))
    plt.axis("off")
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)


###### vgg preprocessing ######
def vgg_preprocess(tensor):
    # input is an RGB tensor ranging in [0, 1]
    # output is a BGR tensor ranging in [0, 255]
    tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
    # tensor_bgr = tensor[:, [2, 1, 0], ...]
    tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(
        1, 3, 1, 1
    )
    tensor_rst = tensor_bgr_ml * 255
    return tensor_rst


def torch_vgg_preprocess(tensor):
    # pytorch version normalization
    # note that both input and output are RGB tensors;
    # input and output range in [0, 1]
    # normalize the tensor with mean and variance
    tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)
    tensor_mc_norm = tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1)
    return tensor_mc_norm
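
# --- illustrative sketch (not part of the original file): both helpers
# expect an RGB tensor in [0, 1] ---
#
#   img = torch.rand(1, 3, 224, 224)
#   caffe_in = vgg_preprocess(img)          # BGR, mean-shifted, scaled by 255
#   torch_in = torch_vgg_preprocess(img)    # RGB, ImageNet mean/std normalized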


def network_gradient(net, gradient_on=True):
    # toggle requires_grad for every parameter of the network
    if gradient_on:
        for param in net.parameters():
            param.requires_grad = True
    else:
        for param in net.parameters():
            param.requires_grad = False
    return net
@@ -1,204 +0,0 @@
# Copyright (c) Microsoft Corporation

import torch.nn as nn
from . import networks


class Mapping_Model_with_mask(nn.Module):
    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        super(Mapping_Model_with_mask, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []

        tmp_nc = 64
        n_up = 4

        for i in range(n_up):
            ic = min(tmp_nc * (2**i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        self.before_NL = nn.Sequential(*model)

        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )

            print("using NL + Res...")

        model = []
        for i in range(n_blocks):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
            ]
        # model += [nn.Conv2d(64, 1, 1, 1, 0)]
        self.after_NL = nn.Sequential(*model)

    def forward(self, input, mask):
        x1 = self.before_NL(input)
        del input
        x2 = self.NL(x1, mask)
        del x1, mask
        x3 = self.after_NL(x2)
        del x2

        return x3


class Mapping_Model_with_mask_2(nn.Module):  ## Multi-Scale Patch Attention
    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        super(Mapping_Model_with_mask_2, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        model = []

        tmp_nc = 64
        n_up = 4

        for i in range(n_up):
            ic = min(tmp_nc * (2**i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]

        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        print("using multi-scale patch attention, conv combine + mask input...")

        self.before_NL = nn.Sequential(*model)

        if opt.mapping_exp == 1:
            self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8)

        model = []
        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        self.res_block_1 = nn.Sequential(*model)

        if opt.mapping_exp == 1:
            self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4)

        model = []
        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        self.res_block_2 = nn.Sequential(*model)

        if opt.mapping_exp == 1:
            self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2)
            # self.NL_scale_3 = networks.Patch_Attention_2(mc, mc, 2)

        model = []
        for i in range(2):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        for i in range(n_up - 1):
            ic = min(64 * (2 ** (4 - i)), mc)
            oc = min(64 * (2 ** (3 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
        if opt.feat_dim > 0 and opt.feat_dim < 64:
            model += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
            ]
        # model += [nn.Conv2d(64, 1, 1, 1, 0)]
        self.after_NL = nn.Sequential(*model)

    def forward(self, input, mask):
        x1 = self.before_NL(input)
        x2 = self.NL_scale_1(x1, mask)
        x3 = self.res_block_1(x2)
        x4 = self.NL_scale_2(x3, mask)
        x5 = self.res_block_2(x4)
        x6 = self.NL_scale_3(x5, mask)
        x7 = self.after_NL(x6)
        return x7

    def inference_forward(self, input, mask):
        x1 = self.before_NL(input)
        del input
        x2 = self.NL_scale_1.inference_forward(x1, mask)
        del x1
        x3 = self.res_block_1(x2)
        del x2
        x4 = self.NL_scale_2.inference_forward(x3, mask)
        del x3
        x5 = self.res_block_2(x4)
        del x4
        x6 = self.NL_scale_3.inference_forward(x5, mask)
        del x5
        x7 = self.after_NL(x6)
        del x6
        return x7
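
# --- illustrative trace (not part of the original file) of the multi-scale
# pipeline above; the scale arguments 8, 4, 2 are passed to Patch_Attention_4
# at each stage ---
#
#   x -> before_NL -> NL_scale_1(., mask) -> res_block_1
#     -> NL_scale_2(., mask) -> res_block_2
#     -> NL_scale_3(., mask) -> after_NL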

@@ -1,122 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import torch
import sys


class BaseModel(torch.nn.Module):
    def name(self):
        return "BaseModel"

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if (torch.cuda.is_available() and self.gpu_ids) else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    def save_optimizer(self, optimizer, optimizer_label, epoch_label):
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(optimizer.state_dict(), save_path)

    def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""):
        save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)

        if not os.path.isfile(save_path):
            print("%s does not exist yet!" % save_path)
        else:
            optimizer.load_state_dict(torch.load(save_path))

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=""):
        save_filename = "%s_net_%s.pth" % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir

        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print("%s does not exist yet!" % save_path)
            # if network_label == 'G':
            #     raise('Generator must exist!')
        else:
            try:
                network.load_state_dict(torch.load(save_path))
            except:
                # fall back to a partial load when the checkpoint and the
                # network architecture do not match exactly
                pretrained_dict = torch.load(save_path)
                model_dict = network.state_dict()
                try:
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    print(
                        "Pretrained network %s has excessive layers; Only loading layers that are used"
                        % network_label
                    )
                except:
                    print(
                        "Pretrained network %s has fewer layers; The following are not initialized:"
                        % network_label
                    )
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3, 0):
                        not_initialized = set()
                    else:
                        from sets import Set

                        not_initialized = Set()

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split(".")[0])

                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate(self):  # was missing `self` in the original signature
        pass
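
# --- illustrative sketch (not part of the original file): the checkpoint
# naming scheme shared by the save/load helpers above ---
#
#   "%s_net_%s.pth" % ("latest", "G")            # -> latest_net_G.pth
#   "%s_optimizer_%s.pth" % ("20", "mapping")    # -> 20_optimizer_mapping.pth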
|
||||
@@ -1,453 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation
|
||||
|
||||
from .NonLocal_feature_mapping_model import *
|
||||
from ..util.image_pool import ImagePool
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
|
||||
|
||||
class Mapping_Model(nn.Module):
|
||||
def __init__(
|
||||
self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
|
||||
):
|
||||
super(Mapping_Model, self).__init__()
|
||||
|
||||
norm_layer = networks.get_norm_layer(norm_type=norm)
|
||||
activation = nn.ReLU(True)
|
||||
model = []
|
||||
tmp_nc = 64
|
||||
n_up = 4
|
||||
|
||||
# print("Mapping: You are using the mapping model without global restoration.")
|
||||
|
||||
for i in range(n_up):
|
||||
ic = min(tmp_nc * (2**i), mc)
|
||||
oc = min(tmp_nc * (2 ** (i + 1)), mc)
|
||||
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
|
||||
for i in range(n_blocks):
|
||||
model += [
|
||||
networks.ResnetBlock(
|
||||
mc,
|
||||
padding_type=padding_type,
|
||||
activation=activation,
|
||||
norm_layer=norm_layer,
|
||||
opt=opt,
|
||||
dilation=opt.mapping_net_dilation,
|
||||
)
|
||||
]
|
||||
|
||||
for i in range(n_up - 1):
|
||||
ic = min(64 * (2 ** (4 - i)), mc)
|
||||
oc = min(64 * (2 ** (3 - i)), mc)
|
||||
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
|
||||
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)]
|
||||
if opt.feat_dim > 0 and opt.feat_dim < 64:
|
||||
model += [
|
||||
norm_layer(tmp_nc),
|
||||
activation,
|
||||
nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
|
||||
]
|
||||
# model += [nn.Conv2d(64, 1, 1, 1, 0)]
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)


class Pix2PixHDModel_Mapping(BaseModel):
    def name(self):
        return "Pix2PixHDModel_Mapping"

    def init_loss_filter(
        self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2
    ):
        flags = (
            True,
            True,
            use_gan_feat_loss,
            use_vgg_loss,
            True,
            True,
            use_smooth_l1,
            stage_1_feat_l2,
        )

        def loss_filter(
            g_feat_l2,
            g_gan,
            g_gan_feat,
            g_vgg,
            d_real,
            d_fake,
            smooth_l1,
            stage_1_feat_l2,
        ):
            return [
                l
                for (l, f) in zip(
                    (
                        g_feat_l2,
                        g_gan,
                        g_gan_feat,
                        g_vgg,
                        d_real,
                        d_fake,
                        smooth_l1,
                        stage_1_feat_l2,
                    ),
                    flags,
                )
                if f
            ]

        return loss_filter

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != "none" or not opt.isTrain:
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        self.netG_A = networks.GlobalGenerator_DCDCv2(
            netG_input_nc,
            opt.output_nc,
            opt.ngf,
            opt.k_size,
            opt.n_downsample_global,
            networks.get_norm_layer(norm_type=opt.norm),
            opt=opt,
        )
        self.netG_B = networks.GlobalGenerator_DCDCv2(
            netG_input_nc,
            opt.output_nc,
            opt.ngf,
            opt.k_size,
            opt.n_downsample_global,
            networks.get_norm_layer(norm_type=opt.norm),
            opt=opt,
        )

        if opt.non_local == "Setting_42" or opt.NL_use_mask:
            if opt.mapping_exp == 1:
                self.mapping_net = Mapping_Model_with_mask_2(
                    min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
                    opt.map_mc,
                    n_blocks=opt.mapping_n_block,
                    opt=opt,
                )
            else:
                self.mapping_net = Mapping_Model_with_mask(
                    min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
                    opt.map_mc,
                    n_blocks=opt.mapping_n_block,
                    opt=opt,
                )
        else:
            self.mapping_net = Mapping_Model(
                min(opt.ngf * 2**opt.n_downsample_global, opt.mc),
                opt.map_mc,
                n_blocks=opt.mapping_n_block,
                opt=opt,
            )

        self.mapping_net.apply(networks.weights_init)

        if opt.load_pretrain != "":
            self.load_network(
                self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain
            )

        if not opt.no_load_VAE:
            self.load_network(
                self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA
            )
            self.load_network(
                self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB
            )
            for param in self.netG_A.parameters():
                param.requires_grad = False
            for param in self.netG_B.parameters():
                param.requires_grad = False
            self.netG_A.eval()
            self.netG_B.eval()

        if torch.cuda.is_available() and opt.gpu_ids:
            self.netG_A.cuda(opt.gpu_ids[0])
            self.netG_B.cuda(opt.gpu_ids[0])
            self.mapping_net.cuda(opt.gpu_ids[0])

        if not self.isTrain:
            self.load_network(self.mapping_net, "mapping_net", opt.which_epoch)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1

            self.netD = networks.define_D(
                netD_input_nc,
                opt.ndf,
                opt.n_layers_D,
                opt,
                opt.norm,
                use_sigmoid,
                opt.num_D,
                not opt.no_ganFeat_loss,
                gpu_ids=self.gpu_ids,
            )

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and len(self.gpu_ids) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(
                not opt.no_ganFeat_loss,
                not opt.no_vgg_loss,
                opt.Smooth_L1,
                opt.use_two_stage_mapping,
            )

            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan, tensor=self.Tensor
            )

            self.criterionFeat = torch.nn.L1Loss()
            self.criterionFeat_feat = (
                torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss()
            )

            if self.opt.image_L1:
                self.criterionImage = torch.nn.L1Loss()
            else:
                self.criterionImage = torch.nn.SmoothL1Loss()

            print(self.criterionFeat_feat)
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

            # Names so we can break out the individual losses
            self.loss_names = self.loss_filter(
                "G_Feat_L2",
                "G_GAN",
                "G_GAN_Feat",
                "G_VGG",
                "D_real",
                "D_fake",
                "Smooth_L1",
                "G_Feat_L2_Stage_1",
            )

            # initialize optimizers
            # optimizer G
            if opt.no_TTUR:
                beta1, beta2 = opt.beta1, 0.999
                G_lr, D_lr = opt.lr, opt.lr
            else:
                beta1, beta2 = 0, 0.9
                G_lr, D_lr = opt.lr / 2, opt.lr * 2

            if not opt.no_load_VAE:
                params = list(self.mapping_net.parameters())
                self.optimizer_mapping = torch.optim.Adam(
                    params, lr=G_lr, betas=(beta1, beta2)
                )

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2))

            print("---------- Optimizers initialized -------------")

    def encode_input(
        self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False
    ):
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(
        self,
        label,
        inst,
        image,
        feat,
        pair=True,
        infer=False,
        last_label=None,
        last_image=None,
    ):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(
            label, inst, image, feat
        )

        # Fake Generation
        input_concat = input_label

        label_feat = self.netG_A.forward(input_concat, flow="enc")
        # print('label:')
        # print(label_feat.min(), label_feat.max(), label_feat.mean())
        # label_feat = label_feat / 16.0

        if self.opt.NL_use_mask:
            label_feat_map = self.mapping_net(label_feat.detach(), inst)
        else:
            label_feat_map = self.mapping_net(label_feat.detach())

        fake_image = self.netG_B.forward(label_feat_map, flow="dec")
        image_feat = self.netG_B.forward(real_image, flow="enc")

        loss_feat_l2_stage_1 = 0
        loss_feat_l2 = (
            self.criterionFeat_feat(label_feat_map, image_feat.data) * self.opt.l2_feat
        )

        if self.opt.feat_gan:
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(
                label_feat.detach(), label_feat_map, use_pool=True
            )
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(label_feat.detach(), image_feat)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss)
            pred_fake = self.netD.forward(
                torch.cat((label_feat.detach(), label_feat_map), dim=1)
            )
            loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            if pair:
                pred_real = self.discriminate(input_label, real_image)
            else:
                pred_real = self.discriminate(last_label, last_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss)
            pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss and pair:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    tmp = (
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
                        * self.opt.lambda_feat
                    )
                    loss_G_GAN_Feat += D_weights * feat_weights * tmp
        else:
            loss_G_GAN_Feat = torch.zeros(1).to(label.device)

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = (
                self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
                if pair
                else torch.zeros(1).to(label.device)
            )

        smooth_l1_loss = 0
        if self.opt.Smooth_L1:
            smooth_l1_loss = (
                self.criterionImage(fake_image, real_image) * self.opt.L1_weight
            )

        return [
            self.loss_filter(
                loss_feat_l2,
                loss_G_GAN,
                loss_G_GAN_Feat,
                loss_G_VGG,
                loss_D_real,
                loss_D_fake,
                smooth_l1_loss,
                loss_feat_l2_stage_1,
            ),
            None if not infer else fake_image,
        ]

    def inference(self, label, inst):
        use_gpu = torch.cuda.is_available() and len(self.opt.gpu_ids) > 0
        if use_gpu:
            input_concat = label.data.cuda()
            inst_data = inst.cuda()
        else:
            input_concat = label.data
            inst_data = inst

        label_feat = self.netG_A.forward(input_concat, flow="enc")

        if self.opt.NL_use_mask:
            if self.opt.inference_optimize:
                label_feat_map = self.mapping_net.inference_forward(
                    label_feat.detach(), inst_data
                )
            else:
                label_feat_map = self.mapping_net(label_feat.detach(), inst_data)
        else:
            label_feat_map = self.mapping_net(label_feat.detach())

        fake_image = self.netG_B.forward(label_feat_map, flow="dec")
        return fake_image


class InferenceModel(Pix2PixHDModel_Mapping):
    def forward(self, label, inst):
        return self.inference(label, inst)
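

# Hedged end-to-end sketch (illustrative only): at test time the model runs
# encode -> mapping -> decode; `photo` and `mask` shapes below are assumptions.
def _example_restore(model, photo, mask):
    # photo: 1 x 3 x H x W normalized input; mask: 1 x 1 x H x W scratch mask
    with torch.no_grad():
        restored = model.inference(photo, mask)
    return restored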

@@ -1,39 +0,0 @@
# Copyright (c) Microsoft Corporation


def create_model(opt):
    assert opt.model == "pix2pixHD"

    from .pix2pixHD_model import Pix2PixHDModel, InferenceModel

    if opt.isTrain:
        model = Pix2PixHDModel()
    else:
        model = InferenceModel()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    assert not opt.isTrain

    return model


def create_da_model(opt):
    assert opt.model == "pix2pixHD"

    from .pix2pixHD_model_DA import Pix2PixHDModel, InferenceModel

    if opt.isTrain:
        model = Pix2PixHDModel()
    else:
        model = InferenceModel()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    assert not opt.isTrain

    return model
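

# Minimal factory sketch (an assumption, not from this file): in practice the
# options object comes from the project's TestOptions parser, since
# model.initialize(opt) reads many more attributes than the factory itself.
def _example_create(opt):
    assert opt.model == "pix2pixHD" and not opt.isTrain
    return create_model(opt)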

@@ -1,875 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
from torch.nn.utils import spectral_norm

# from util.util import SwitchNorm2d
import torch.nn.functional as F

###############################################################################
# Functions
###############################################################################
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type="instance"):
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == "instance":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == "spectral":
        # spectral_norm wraps a module directly; calling it with no arguments
        # (as the original code did) would raise a TypeError
        norm_layer = spectral_norm
    elif norm_type == "SwitchNorm":
        # NOTE: this branch needs the SwitchNorm2d import above to be restored
        norm_layer = SwitchNorm2d
    else:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return norm_layer


def print_network(net):
    if isinstance(net, list):
        net = net[0]
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print("Total number of parameters: %d" % num_params)


def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[], opt=None):
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        # if opt.self_gen:
        if opt.use_v2:
            netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt)
        else:
            netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt)
    else:
        raise NotImplementedError('generator [%s] is not implemented' % netG)
    print(netG)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG


def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    norm_layer = get_norm_layer(norm_type=norm)
    netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
    print(netD)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD
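

# Illustrative sketch (not in the original file): get_norm_layer returns a
# constructor rather than a layer, so the networks above can instantiate a
# fresh norm module for every channel width.
def _example_norm_factory():
    norm_layer = get_norm_layer("instance")
    layer = norm_layer(64)  # an InstanceNorm2d over 64 channels
    return layer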


class GlobalGenerator_DCDCv2(nn.Module):
    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        k_size=3,
        n_downsampling=8,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        opt=None,
    ):
        super(GlobalGenerator_DCDCv2, self).__init__()
        activation = nn.ReLU(True)

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0),
            norm_layer(ngf),
            activation,
        ]
        ### downsample
        for i in range(opt.start_r):
            mult = 2 ** i
            model += [
                nn.Conv2d(
                    min(ngf * mult, opt.mc),
                    min(ngf * mult * 2, opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                ),
                norm_layer(min(ngf * mult * 2, opt.mc)),
                activation,
            ]
        for i in range(opt.start_r, n_downsampling - 1):
            mult = 2 ** i
            model += [
                nn.Conv2d(
                    min(ngf * mult, opt.mc),
                    min(ngf * mult * 2, opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                ),
                norm_layer(min(ngf * mult * 2, opt.mc)),
                activation,
            ]
            model += [
                ResnetBlock(
                    min(ngf * mult * 2, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
            model += [
                ResnetBlock(
                    min(ngf * mult * 2, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
        mult = 2 ** (n_downsampling - 1)

        if opt.spatio_size == 32:
            model += [
                nn.Conv2d(
                    min(ngf * mult, opt.mc),
                    min(ngf * mult * 2, opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                ),
                norm_layer(min(ngf * mult * 2, opt.mc)),
                activation,
            ]
        if opt.spatio_size == 64:
            model += [
                ResnetBlock(
                    min(ngf * mult * 2, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
        model += [
            ResnetBlock(
                min(ngf * mult * 2, opt.mc),
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )
        ]
        # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)]
        if opt.feat_dim > 0:
            model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)]
        self.encoder = nn.Sequential(*model)

        # decode
        model = []
        if opt.feat_dim > 0:
            model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)]
        # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)]
        o_pad = 0 if k_size == 4 else 1
        mult = 2 ** n_downsampling
        model += [
            ResnetBlock(
                min(ngf * mult, opt.mc),
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )
        ]

        if opt.spatio_size == 32:
            model += [
                nn.ConvTranspose2d(
                    min(ngf * mult, opt.mc),
                    min(int(ngf * mult / 2), opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(min(int(ngf * mult / 2), opt.mc)),
                activation,
            ]
        if opt.spatio_size == 64:
            model += [
                ResnetBlock(
                    min(ngf * mult, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]

        for i in range(1, n_downsampling - opt.start_r):
            mult = 2 ** (n_downsampling - i)
            model += [
                ResnetBlock(
                    min(ngf * mult, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
            model += [
                ResnetBlock(
                    min(ngf * mult, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
            model += [
                nn.ConvTranspose2d(
                    min(ngf * mult, opt.mc),
                    min(int(ngf * mult / 2), opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(min(int(ngf * mult / 2), opt.mc)),
                activation,
            ]
        for i in range(n_downsampling - opt.start_r, n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    min(ngf * mult, opt.mc),
                    min(int(ngf * mult / 2), opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(min(int(ngf * mult / 2), opt.mc)),
                activation,
            ]
        if opt.use_segmentation_model:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)]
        else:
            model += [
                nn.ReflectionPad2d(3),
                nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0),
                nn.Tanh(),
            ]
        self.decoder = nn.Sequential(*model)

    def forward(self, input, flow="enc_dec"):
        if flow == "enc":
            return self.encoder(input)
        elif flow == "dec":
            return self.decoder(input)
        elif flow == "enc_dec":
            x = self.encoder(input)
            x = self.decoder(x)
            return x
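

# Sketch of how this generator is driven in the restoration pipeline (the
# argument names are assumptions; netG_A/netG_B/mapping_net mirror
# Pix2PixHDModel_Mapping): encoder and decoder are callable in isolation so a
# mapping network can sit between the two domains.
def _example_enc_dec(netG_A, netG_B, mapping_net, photo):
    feat = netG_A.forward(photo, flow="enc")       # encode in the old-photo domain
    mapped = mapping_net(feat.detach())            # translate the latent features
    restored = netG_B.forward(mapped, flow="dec")  # decode in the clean-photo domain
    return restored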


# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(
        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
    ):
        super(ResnetBlock, self).__init__()
        self.opt = opt
        self.dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(self.dilation)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(self.dilation)]
        elif padding_type == "zero":
            p = self.dilation
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation),
            norm_layer(dim),
            activation,
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class Encoder(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1
                ),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]

        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        outputs = self.model(input)

        # instance-wise average pooling
        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b : b + 1] == int(i)).nonzero()  # n x 4
                for j in range(self.output_nc):
                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[
                        indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]
                    ] = mean_feat
        return outputs_mean


def SN(module, mode=True):
    if mode:
        return torch.nn.utils.spectral_norm(module)

    return module
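

# Usage sketch: SN conditionally wraps a layer in spectral normalization, so
# the discriminators below can toggle it with a single flag (opt.use_SN).
def _example_sn():
    return SN(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), mode=True)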


class NonLocalBlock2D_with_mask_Res(nn.Module):
    def __init__(
        self,
        in_channels,
        inter_channels,
        mode="add",
        re_norm=False,
        temperature=1.0,
        use_self=False,
        cosin=False,
    ):
        super(NonLocalBlock2D_with_mask_Res, self).__init__()

        self.cosin = cosin
        self.renorm = re_norm
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.W = nn.Conv2d(
            in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
        )
        # for pytorch 0.3.1
        # nn.init.constant(self.W.weight, 0)
        # nn.init.constant(self.W.bias, 0)
        # for pytorch 0.4.0
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.phi = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.mode = mode
        self.temperature = temperature
        self.use_self = use_self

        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(3):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def forward(self, x, mask):  ## The shape of mask is Batch*1*H*W
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        if self.cosin:
            theta_x = F.normalize(theta_x, dim=2)
            phi_x = F.normalize(phi_x, dim=1)

        f = torch.matmul(theta_x, phi_x)
        f /= self.temperature
        f_div_C = F.softmax(f, dim=2)

        tmp = 1 - mask
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        mask = 1 - mask

        tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
        mask *= tmp

        mask_expand = mask.view(batch_size, 1, -1)
        mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)

        # mask = 1 - mask
        # mask = F.interpolate(mask, (x.size(2), x.size(3)))
        # mask_expand = mask.view(batch_size, 1, -1)
        # mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)

        if self.use_self:
            mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0

        # print(mask_expand.shape)
        # print(f_div_C.shape)

        f_div_C = mask_expand * f_div_C
        if self.renorm:
            f_div_C = F.normalize(f_div_C, p=1, dim=2)

        ###########################

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        W_y = self.res_block(W_y)

        if self.mode == "combine":
            full_mask = mask.repeat(1, self.inter_channels, 1, 1)
            z = full_mask * x + (1 - full_mask) * W_y
            return z


class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:
                for j in range(n_layers + 2):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), opt.use_SN), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), opt.use_SN),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), opt.use_SN),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), opt.use_SN)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)
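

# Output-structure sketch (illustrative): the multiscale discriminator returns
# one list per scale; with getIntermFeat=True each inner list holds every
# intermediate feature map, and its last entry is that scale's patch logits.
def _example_d_outputs(netD, pair):
    preds = netD(pair)               # list over scales
    logits = [p[-1] for p in preds]  # one logits map per scale
    return logits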


class Patch_Attention_4(nn.Module):  ## While combining the feature maps, use conv and mask
    def __init__(self, in_channels, inter_channels, patch_size):
        super(Patch_Attention_4, self).__init__()

        self.patch_size = patch_size

        # in_channels (512) + composed features (512) + mask (1) -> 1025 inputs
        self.F_Combine = nn.Conv2d(in_channels=1025, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(1):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def Hard_Compose(self, input, dim, index):
        # batch index select
        # input: [B, C, HW]
        # dim: scalar > 0
        # index: [B, HW]
        views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))]
        expanse = list(input.size())
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def forward(self, z, mask):  ## The shape of mask is Batch*1*H*W
        x = self.res_block(z)

        b, c, h, w = x.shape

        ## mask resize + dilation
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        ## 1: mask position 0: non-mask

        mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()
        all_patch_num = h * w / self.patch_size / self.patch_size
        non_mask_region = non_mask_region.repeat(1, int(all_patch_num), 1)

        x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        y_unfold = x_unfold.permute(0, 2, 1)
        x_unfold_normalized = F.normalize(x_unfold, dim=1)
        y_unfold_normalized = F.normalize(y_unfold, dim=2)
        correlation_matrix = torch.bmm(y_unfold_normalized, x_unfold_normalized)
        correlation_matrix = correlation_matrix.masked_fill(non_mask_region == 1., -1e9)
        correlation_matrix = F.softmax(correlation_matrix, dim=2)

        # print(correlation_matrix)

        R, max_arg = torch.max(correlation_matrix, dim=2)

        composed_unfold = self.Hard_Compose(x_unfold, 2, max_arg)
        composed_fold = F.fold(composed_unfold, output_size=(h, w), kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)

        return concat_1

    def inference_forward(self, z, mask):  ## Reduce the extra memory cost
        x = self.res_block(z)

        b, c, h, w = x.shape

        ## mask resize + dilation
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        ## 1: mask position 0: non-mask

        mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()[0, 0, :]  # 1*1*all_patch_num

        all_patch_num = h * w / self.patch_size / self.patch_size

        mask_index = torch.nonzero(non_mask_region, as_tuple=True)[0]

        if len(mask_index) == 0:  ## no masked patch is selected, so no attention is needed
            composed_fold = x
        else:
            unmask_index = torch.nonzero(non_mask_region != 1, as_tuple=True)[0]

            x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

            Query_Patch = torch.index_select(x_unfold, 2, mask_index)
            Key_Patch = torch.index_select(x_unfold, 2, unmask_index)

            Query_Patch = Query_Patch.permute(0, 2, 1)
            Query_Patch_normalized = F.normalize(Query_Patch, dim=2)
            Key_Patch_normalized = F.normalize(Key_Patch, dim=1)

            correlation_matrix = torch.bmm(Query_Patch_normalized, Key_Patch_normalized)
            correlation_matrix = F.softmax(correlation_matrix, dim=2)

            R, max_arg = torch.max(correlation_matrix, dim=2)

            composed_unfold = self.Hard_Compose(Key_Patch, 2, max_arg)
            x_unfold[:, :, mask_index] = composed_unfold
            composed_fold = F.fold(x_unfold, output_size=(h, w), kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)

        return concat_1
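

# Shape sketch for the patch attention above (all values are assumptions):
# each sufficiently-masked patch is replaced by its most similar unmasked
# patch before the convolutional fusion in F_Combine.
def _example_patch_attention():
    pa = Patch_Attention_4(in_channels=512, inter_channels=512, patch_size=4)
    z = torch.randn(1, 512, 64, 64)     # bottleneck features
    mask = torch.zeros(1, 1, 256, 256)
    mask[:, :, 100:140, 100:140] = 1.0  # a hypothetical scratch region
    return pa(z, mask)                  # 1 x 512 x 64 x 64 fused features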


##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)
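

# Usage sketch (illustrative): GANLoss accepts the nested list structure the
# multiscale discriminator produces, so discriminator updates stay one-liners.
def _example_gan_loss(netD, real_pair, fake_pair):
    # on GPU, pass tensor=torch.cuda.FloatTensor as the model classes do
    criterion = GANLoss(use_lsgan=True)
    return criterion(netD(real_pair), True) + criterion(netD(fake_pair), False)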


####################################### VGG Loss

from torchvision import models


class VGG19_torch(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG19_torch, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss_torch(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss_torch, self).__init__()
        self.vgg = VGG19_torch().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
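

# Perceptual-loss sketch (assumes a CUDA device, since VGGLoss_torch moves
# VGG19 to the GPU unconditionally): deeper layers carry larger weights, so
# semantic agreement dominates pixel-level texture differences.
def _example_vgg_loss(fake_image, real_image):
    criterion = VGGLoss_torch(gpu_ids=[0])
    return criterion(fake_image, real_image)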

@@ -1,333 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_L1):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, use_smooth_L1)

        def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, smooth_l1):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, smooth_l1), flags) if f]

        return loss_filter

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this avoids OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat  # False in this configuration
        self.gen_features = self.use_features and not self.opt.load_features  # also False here
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc  # the original input channel count

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)

            print("---------- G Networks reloaded -------------")
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
                print("---------- D Networks reloaded -------------")

            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and len(self.gpu_ids) > 1:  # pool_size is 0 here
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1)

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()

            # self.criterionImage = torch.nn.SmoothL1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1')

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            print("---------- Optimizers initialized -------------")

            if opt.continue_train:
                # re-load the optimizer state as well, so training resumes smoothly
                self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
                self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
                for param_groups in self.optimizer_D.param_groups:
                    self.old_lr = param_groups['lr']

                print("---------- Optimizers reloaded -------------")
                print("---------- Current LR is %.8f -------------" % (self.old_lr))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
                inst_map = label_map.cuda()

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        if input_label is None:
            input_concat = test_image.detach()
        else:
            input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, infer=False):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        hiddens = self.netG.forward(input_concat, 'enc')
        noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
        # This is a reduced VAE implementation where we assume the outputs are a
        # multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
        # We follow the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py)
        fake_image = self.netG.forward(hiddens + noise, 'dec')

        if self.opt.no_cgan:
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(None, fake_image, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(None, real_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss)
            pred_fake = self.netD.forward(fake_image)
            loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(input_label, real_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss)
            pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

        loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        smooth_l1_loss = 0

        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake, smooth_l1_loss), None if not infer else fake_image]

    def inference(self, label, inst, image=None, feat=None):
        # Encode Inputs
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

        # Fake Generation
        if self.use_features:
            if self.opt.use_encoded_image:
                # encode the real image to get the feature map
                feat_map = self.netE.forward(real_image, inst_map)
            else:
                # sample clusters from precomputed features
                feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label

        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_concat)
        else:
            fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)

        self.save_optimizer(self.optimizer_G, "G", which_epoch)
        self.save_optimizer(self.optimizer_D, "D", which_epoch)

        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


class InferenceModel(Pix2PixHDModel):
    def forward(self, inp):
        label, inst = inp
        return self.inference(label, inst)
|
||||
@@ -1,372 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
from torch.autograd import Variable
|
||||
from util.image_pool import ImagePool
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
|
||||
class Pix2PixHDModel(BaseModel):
|
||||
def name(self):
|
||||
return 'Pix2PixHDModel'
|
||||
|
||||
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
|
||||
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True, True)
|
||||
|
||||
def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake):
|
||||
return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake), flags) if f]
|
||||
|
||||
return loss_filter
|
||||
|
||||
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat  # False in this configuration
        self.gen_features = self.use_features and not self.opt.load_features  # also False
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc  # the original input channel count

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

            self.feat_D = networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid,
                                            1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)

            print("---------- G Networks reloaded -------------")
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
                self.load_network(self.feat_D, 'feat_D', opt.which_epoch, pretrained_path)
                print("---------- D Networks reloaded -------------")

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:  # pool_size is 0 here
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids)

            # Names so we can break out each loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'G_featD', 'featD_real', 'featD_fake')

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            params = list(self.feat_D.parameters())
            self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            print("---------- Optimizers initialized -------------")

            if opt.continue_train:
                self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch)
                self.load_optimizer(self.optimizer_G, "G", opt.which_epoch)
                self.load_optimizer(self.optimizer_featD, 'featD', opt.which_epoch)
                for param_groups in self.optimizer_D.param_groups:
                    self.old_lr = param_groups['lr']

                print("---------- Optimizers reloaded -------------")
                print("---------- Current LR is %.8f -------------" % (self.old_lr))

        ## We also want to re-load the parameters of the optimizer.
    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
                inst_map = label_map.cuda()

        return input_label, inst_map, real_image, feat_map
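# Illustrative sketch (editor's addition, not part of the diff): the scatter_
# call above converts an integer label map of shape (N, 1, H, W) into a one-hot
# tensor of shape (N, label_nc, H, W). A minimal CPU version of the same operation:
import torch

label_map = torch.tensor([[[[0, 2], [1, 0]]]])            # (1, 1, 2, 2), values in [0, label_nc)
label_nc = 3
one_hot = torch.zeros(1, label_nc, 2, 2).scatter_(1, label_map, 1.0)
assert one_hot[0, 2, 0, 1] == 1.0                         # pixel (0, 1) carries label 2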
    def discriminate(self, input_label, test_image, use_pool=False):
        if input_label is None:
            input_concat = test_image.detach()
        else:
            input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)

    def feat_discriminate(self, input):
        return self.feat_D.forward(input.detach())
    def forward(self, label, inst, image, feat, infer=False):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        hiddens = self.netG.forward(input_concat, 'enc')
        noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
        # This is a reduced VAE implementation where we assume the outputs are a multivariate
        # Gaussian distribution with mean = hiddens and std_dev = all ones.
        # We follow the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py)
        fake_image = self.netG.forward(hiddens + noise, 'dec')

        ####################
        ##### GAN for the intermediate feature
        real_old_feat = []
        syn_feat = []
        for index, x in enumerate(inst):
            if x == 1:
                real_old_feat.append(hiddens[index].unsqueeze(0))
            else:
                syn_feat.append(hiddens[index].unsqueeze(0))
        L = min(len(real_old_feat), len(syn_feat))
        real_old_feat = real_old_feat[:L]
        syn_feat = syn_feat[:L]
        real_old_feat = torch.cat(real_old_feat, 0)
        syn_feat = torch.cat(syn_feat, 0)

        pred_fake_feat = self.feat_discriminate(real_old_feat)
        loss_featD_fake = self.criterionGAN(pred_fake_feat, False)
        pred_real_feat = self.feat_discriminate(syn_feat)
        loss_featD_real = self.criterionGAN(pred_real_feat, True)

        pred_fake_feat_G = self.feat_D.forward(real_old_feat)
        loss_G_featD = self.criterionGAN(pred_fake_feat_G, True)

        #####################################
        if self.opt.no_cgan:
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(None, fake_image, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(None, real_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss)
            pred_fake = self.netD.forward(fake_image)
            loss_G_GAN = self.criterionGAN(pred_fake, True)
        else:
            # Fake Detection and Loss
            pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
            loss_D_fake = self.criterionGAN(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = self.discriminate(input_label, real_image)
            loss_D_real = self.criterionGAN(pred_real, True)

            # GAN loss (Fake Passability Loss)
            pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
            loss_G_GAN = self.criterionGAN(pred_fake, True)

        loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j],
                                           pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary, to save bandwidth
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake, loss_G_featD, loss_featD_real, loss_featD_fake),
                None if not infer else fake_image]
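# Illustrative sketch (editor's addition, not part of the diff): with the decoder
# input modeled as N(hiddens, I), the KL divergence to a standard normal prior
# reduces to 0.5 * sum(hiddens ** 2) per sample; loss_G_kl above uses the mean of
# the squared activations, i.e. the same quantity up to a constant factor.
import torch

h = torch.randn(4, 64)                                    # toy latent codes
kl_exact = 0.5 * (h ** 2).sum(dim=1).mean()               # KL(N(h, I) || N(0, I)) averaged over the batch
kl_used = (h ** 2).mean()                                 # the reduced form used in the model
assert torch.isclose(kl_exact, kl_used * h.shape[1] / 2)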
    def inference(self, label, inst, image=None, feat=None):
        # Encode Inputs
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

        # Fake Generation
        if self.use_features:
            if self.opt.use_encoded_image:
                # encode the real image to get the feature map
                feat_map = self.netE.forward(real_image, inst_map)
            else:
                # sample clusters from precomputed features
                feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label

        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_concat)
        else:
            fake_image = self.netG.forward(input_concat)
        return fake_image
    def sample_features(self, inst):
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map
    def encode_features(self, image, inst):
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature
    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()
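# Illustrative sketch (editor's addition, not part of the diff): get_edges flags
# every pixel whose horizontal or vertical neighbor belongs to a different
# instance. A CPU toy run of the same comparisons:
import torch

t = torch.tensor([[[[1, 1, 2], [1, 1, 2], [3, 3, 3]]]])   # (1, 1, 3, 3) instance map
edge = torch.zeros_like(t, dtype=torch.bool)
edge[:, :, :, 1:] |= t[:, :, :, 1:] != t[:, :, :, :-1]
edge[:, :, :, :-1] |= t[:, :, :, 1:] != t[:, :, :, :-1]
edge[:, :, 1:, :] |= t[:, :, 1:, :] != t[:, :, :-1, :]
edge[:, :, :-1, :] |= t[:, :, 1:, :] != t[:, :, :-1, :]
# pixels bordering the 1/2 and 1/3 instance boundaries are flagged
assert bool(edge[0, 0, 0, 1]) and bool(edge[0, 0, 2, 0])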
    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        self.save_network(self.feat_D, 'featD', which_epoch, self.gpu_ids)

        self.save_optimizer(self.optimizer_G, "G", which_epoch)
        self.save_optimizer(self.optimizer_D, "D", which_epoch)
        self.save_optimizer(self.optimizer_featD, 'featD', which_epoch)

        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_featD.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


class InferenceModel(Pix2PixHDModel):
    def forward(self, inp):
        label, inst = inp
        return self.inference(label, inst)
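# Illustrative sketch (editor's addition, not part of the diff):
# update_learning_rate implements a linear decay schedule -- each call subtracts
# a fixed step lr / niter_decay, so the rate reaches zero after niter_decay calls:
base_lr, niter_decay = 2e-4, 100
lr = base_lr
for epoch in range(niter_decay):
    lr -= base_lr / niter_decay
assert abs(lr) < 1e-12  # decayed to (numerically) zero after niter_decay steps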
@@ -1,469 +0,0 @@
# Copyright (c) Microsoft Corporation

import argparse
import torch


class BaseOptions:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        # experiment specifics
        self.parser.add_argument("--name", type=str, default="label2city", help="name of the experiment; decides where samples and models are stored")
        self.parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids, e.g. '0' or '0,1,2'; use -1 for CPU")
        self.parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here")  # note: add this param when using philly
        # self.parser.add_argument('--project_dir', type=str, default='./', help='the project is saved here')  # necessary for philly
        self.parser.add_argument("--outputs_dir", type=str, default="./outputs", help="outputs are saved here")  # note: add this param when using philly; please end with '/'
        self.parser.add_argument("--model", type=str, default="pix2pixHD", help="which model to use")
        self.parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization")
        self.parser.add_argument("--use_dropout", action="store_true", help="use dropout for the generator")
        self.parser.add_argument("--data_type", default=32, type=int, choices=[8, 16, 32], help="supported data type, i.e. 8, 16, 32 bit")
        self.parser.add_argument("--verbose", action="store_true", default=False, help="toggles verbose output")

        # input/output sizes
        self.parser.add_argument("--batchSize", type=int, default=1, help="input batch size")
        self.parser.add_argument("--loadSize", type=int, default=1024, help="scale images to this size")
        self.parser.add_argument("--fineSize", type=int, default=512, help="then crop to this size")
        self.parser.add_argument("--label_nc", type=int, default=35, help="# of input label channels")
        self.parser.add_argument("--input_nc", type=int, default=3, help="# of input image channels")
        self.parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels")

        # for setting inputs
        self.parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/")
        self.parser.add_argument("--resize_or_crop", type=str, default="scale_width", help="scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]")
        self.parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        self.parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        self.parser.add_argument("--nThreads", default=2, type=int, help="# of threads for loading data")
        self.parser.add_argument("--max_dataset_size", type=int, default=float("inf"), help="maximum number of samples allowed per dataset; if the dataset directory contains more than max_dataset_size, only a subset is loaded")

        # for displays
        self.parser.add_argument("--display_winsize", type=int, default=512, help="display window size")
        self.parser.add_argument("--tf_log", action="store_true", help="if specified, use tensorboard logging; requires tensorflow installed")

        # for generator
        self.parser.add_argument("--netG", type=str, default="global", help="selects model to use for netG")
        self.parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer")
        self.parser.add_argument("--k_size", type=int, default=3, help="kernel size of the conv layers")
        self.parser.add_argument("--use_v2", action="store_true", help="use DCDCv2")
        self.parser.add_argument("--mc", type=int, default=1024, help="max channel count")
        self.parser.add_argument("--start_r", type=int, default=3, help="start layer to use resblock")
        self.parser.add_argument("--n_downsample_global", type=int, default=4, help="number of downsampling layers in netG")
        self.parser.add_argument("--n_blocks_global", type=int, default=9, help="number of residual blocks in the global generator network")
        self.parser.add_argument("--n_blocks_local", type=int, default=3, help="number of residual blocks in the local enhancer network")
        self.parser.add_argument("--n_local_enhancers", type=int, default=1, help="number of local enhancers to use")
        self.parser.add_argument("--niter_fix_global", type=int, default=0, help="number of epochs that we only train the outmost local enhancer")

        self.parser.add_argument("--load_pretrain", type=str, default="", help="load the pretrained model from the specified location")

        # for instance-wise features
        self.parser.add_argument("--no_instance", action="store_true", help="if specified, do *not* add instance map as input")
        self.parser.add_argument("--instance_feat", action="store_true", help="if specified, add encoded instance features as input")
        self.parser.add_argument("--label_feat", action="store_true", help="if specified, add encoded label features as input")
        self.parser.add_argument("--feat_num", type=int, default=3, help="vector length for encoded features")
        self.parser.add_argument("--load_features", action="store_true", help="if specified, load precomputed feature maps")
        self.parser.add_argument("--n_downsample_E", type=int, default=4, help="# of downsampling layers in encoder")
        self.parser.add_argument("--nef", type=int, default=16, help="# of encoder filters in the first conv layer")
        self.parser.add_argument("--n_clusters", type=int, default=10, help="number of clusters for features")

        # diy
        self.parser.add_argument("--self_gen", action="store_true", help="self generate")
        self.parser.add_argument("--mapping_n_block", type=int, default=3, help="number of resblocks in mapping")
        self.parser.add_argument("--map_mc", type=int, default=64, help="max channel of mapping")
        self.parser.add_argument("--kl", type=float, default=0, help="KL loss weight")
        self.parser.add_argument("--load_pretrainA", type=str, default="", help="load pretrained model A from the specified location")
        self.parser.add_argument("--load_pretrainB", type=str, default="", help="load pretrained model B from the specified location")
        self.parser.add_argument("--feat_gan", action="store_true")
        self.parser.add_argument("--no_cgan", action="store_true")
        self.parser.add_argument("--map_unet", action="store_true")
        self.parser.add_argument("--map_densenet", action="store_true")
        self.parser.add_argument("--fcn", action="store_true")
        self.parser.add_argument("--is_image", action="store_true", help="train image reconstruction only on pair data")
        self.parser.add_argument("--label_unpair", action="store_true")
        self.parser.add_argument("--mapping_unpair", action="store_true")
        self.parser.add_argument("--unpair_w", type=float, default=1.0)
        self.parser.add_argument("--pair_num", type=int, default=-1)
        self.parser.add_argument("--Gan_w", type=float, default=1)
        self.parser.add_argument("--feat_dim", type=int, default=-1)
        self.parser.add_argument("--abalation_vae_len", type=int, default=-1)

        # unused; kept only to cooperate with docker
        self.parser.add_argument("--gpu", type=str)
        self.parser.add_argument("--dataDir", type=str)
        self.parser.add_argument("--modelDir", type=str)
        self.parser.add_argument("--logDir", type=str)
        self.parser.add_argument("--data_dir", type=str)

        self.parser.add_argument("--use_skip_model", action="store_true")
        self.parser.add_argument("--use_segmentation_model", action="store_true")

        self.parser.add_argument("--spatio_size", type=int, default=64)
        self.parser.add_argument("--test_random_crop", action="store_true")

        self.parser.add_argument("--contain_scratch_L", action="store_true")
        self.parser.add_argument("--mask_dilation", type=int, default=0)  # don't change the input, only dilate the mask

        self.parser.add_argument("--irregular_mask", type=str, default="", help="root of the mask")
        self.parser.add_argument("--mapping_net_dilation", type=int, default=1, help="dilation size of the translation net")

        self.parser.add_argument("--VOC", type=str, default="VOC_RGB_JPEGImages.bigfile", help="root of the VOC dataset")

        self.parser.add_argument("--non_local", type=str, default="", help="which non_local setting to use")
        self.parser.add_argument("--NL_fusion_method", type=str, default="add", help="how to fuse the original feature and the non-local feature")
        self.parser.add_argument("--NL_use_mask", action="store_true", help="use the mask with the non-local mapping model")
        self.parser.add_argument("--correlation_renormalize", action="store_true", help="after masking the (softmaxed) correlation matrix its rows no longer sum to 1; enable this to re-normalize")

        self.parser.add_argument("--Smooth_L1", action="store_true", help="use smooth L1 loss at the image level")

        self.parser.add_argument("--face_restore_setting", type=int, default=1, help="setting for the aligned face restoration")
        self.parser.add_argument("--face_clean_url", type=str, default="")
        self.parser.add_argument("--syn_input_url", type=str, default="")
        self.parser.add_argument("--syn_gt_url", type=str, default="")

        self.parser.add_argument("--test_on_synthetic", action="store_true", help="enable this parameter to test on synthetic data")

        self.parser.add_argument("--use_SN", action="store_true", help="add spectral normalization to every parametric layer")

        self.parser.add_argument("--use_two_stage_mapping", action="store_true", help="choose the model which uses two stages")

        self.parser.add_argument("--L1_weight", type=float, default=10.0)
        self.parser.add_argument("--softmax_temperature", type=float, default=1.0)
        self.parser.add_argument("--patch_similarity", action="store_true", help="use 3x3 patches to calculate similarity")
        self.parser.add_argument("--use_self", action="store_true", help="while constructing the new feature maps, also use the original feature (diagonal == 1)")

        self.parser.add_argument("--use_own_dataset", action="store_true")

        self.parser.add_argument("--test_hole_two_folders", action="store_true", help="test the restoration with inpainting, given two folders containing the masks and the old photos respectively")

        self.parser.add_argument("--no_hole", action="store_true", help="while testing the full model on non-scratch data, do not add random masks to the real old photos")  # only for testing
        self.parser.add_argument("--random_hole", action="store_true", help="while training the full model, add a hole with 50% probability")

        self.parser.add_argument("--NL_res", action="store_true", help="non-local + residual block")

        self.parser.add_argument("--image_L1", action="store_true", help="image-level loss: L1")
        self.parser.add_argument("--hole_image_no_mask", action="store_true", help="while testing, give the hole image but not the mask")

        self.parser.add_argument("--down_sample_degradation", action="store_true", help="downsample the image only; corresponds to [down_sample_face]")

        self.parser.add_argument("--norm_G", type=str, default="spectralinstance", help="norm type of the generator")
        self.parser.add_argument("--init_G", type=str, default="xavier", help="normal|xavier|xavier_uniform|kaiming|orthogonal|none")

        self.parser.add_argument("--use_new_G", action="store_true")
        self.parser.add_argument("--use_new_D", action="store_true")

        self.parser.add_argument("--only_voc", action="store_true", help="test the trained celebA face model on VOC faces")

        self.parser.add_argument("--cosin_similarity", action="store_true", help="for non-local, use cosine similarity")

        self.parser.add_argument("--downsample_mode", type=str, default="nearest", help="for partial non-local, how to downsample the mask")

        self.parser.add_argument("--mapping_exp", type=int, default=0, help="0: original PNL | 1: Multi-Scale Patch Attention")
        self.parser.add_argument("--inference_optimize", action="store_true", help="optimize the memory cost")

        self.initialized = True

    def parse(self, custom_args: list):
        assert len(custom_args) > 0, "Manually Pass Arguments!"

        if not self.initialized:
            self.initialize()

        self.opt = self.parser.parse_args(custom_args)
        self.opt.isTrain = self.isTrain  # train or test

        str_ids = self.opt.gpu_ids.split(",")
        self.opt.gpu_ids = []
        for str_id in str_ids:
            int_id = int(str_id)
            if int_id >= 0:
                self.opt.gpu_ids.append(int_id)

        # set gpu ids
        if torch.cuda.is_available() and len(self.opt.gpu_ids) > 0:
            if torch.cuda.device_count() > self.opt.gpu_ids[0]:
                try:
                    torch.cuda.set_device(self.opt.gpu_ids[0])
                except Exception:
                    print("Failed to set GPU device. Using CPU...")
            else:
                print("Invalid GPU ID. Using CPU...")

        return self.opt
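# Illustrative sketch (editor's addition, not part of the diff): because parse()
# takes an explicit argument list instead of reading sys.argv, the extension can
# drive it programmatically. _DemoOptions is a hypothetical subclass that only
# supplies the isTrain flag that concrete option classes are expected to set.
class _DemoOptions(BaseOptions):
    isTrain = False

opt = _DemoOptions()
parsed = opt.parse(["--gpu_ids", "-1", "--verbose"])      # -1 keeps everything on the CPU
print(parsed.name, parsed.gpu_ids)                        # -> label2city []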
@@ -1,100 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .base_options import BaseOptions


class TestOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument("--ntest", type=int, default=float("inf"), help="# of test examples")
        self.parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here")
        self.parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of result images")
        self.parser.add_argument("--phase", type=str, default="test", help="train, val, test, etc.")
        self.parser.add_argument("--which_epoch", type=str, default="latest", help="which epoch to load; set to latest to use the latest cached model")
        self.parser.add_argument("--how_many", type=int, default=50, help="how many test images to run")
        self.parser.add_argument("--cluster_path", type=str, default="features_clustered_010.npy", help="path to the clustered results of encoded features")
        self.parser.add_argument("--use_encoded_image", action="store_true", help="if specified, encode the real image to get the feature map")
        self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
        self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
        self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
        self.parser.add_argument("--start_epoch", type=int, default=-1, help="write the start_epoch of iter.txt into this parameter")

        self.parser.add_argument("--test_dataset", type=str, default="Real_RGB_old.bigfile")
        self.parser.add_argument("--no_degradation", action="store_true", help="when training the mapping, enable this parameter so that no degradation is added to the clean image")
        self.parser.add_argument("--no_load_VAE", action="store_true", help="when training the mapping, enable this parameter to randomly initialize the encoder and decoder")
        self.parser.add_argument("--use_v2_degradation", action="store_true", help="enable this parameter so that 4 kinds of degradations are used to synthesize corruption")
        self.parser.add_argument("--use_vae_which_epoch", type=str, default="latest")
        self.isTrain = False

        self.parser.add_argument("--generate_pair", action="store_true")

        self.parser.add_argument("--multi_scale_test", type=float, default=0.5)
        self.parser.add_argument("--multi_scale_threshold", type=float, default=0.5)
        self.parser.add_argument("--mask_need_scale", action="store_true", help="enable this parameter when the pixel range of the mask is 0-255")
        self.parser.add_argument("--scale_num", type=int, default=1)

        self.parser.add_argument("--save_feature_url", type=str, default="", help="where to put the extracted features")

        self.parser.add_argument("--test_input", type=str, default="", help="a directory or a root of a bigfile")
        self.parser.add_argument("--test_mask", type=str, default="", help="a directory or a root of a bigfile")
        self.parser.add_argument("--test_gt", type=str, default="", help="a directory or a root of a bigfile")

        self.parser.add_argument("--scale_input", action="store_true", help="while testing, scale the input first")

        self.parser.add_argument("--save_feature_name", type=str, default="features.json", help="name of the saved features file")
        self.parser.add_argument("--test_rgb_old_wo_scratch", action="store_true", help="same setting as the original test")

        self.parser.add_argument("--test_mode", type=str, default="Crop", help="Scale|Full|Crop")
        self.parser.add_argument("--Quality_restore", action="store_true", help="for RGB images")
        self.parser.add_argument("--Scratch_and_Quality_restore", action="store_true", help="for scratched images")
        self.parser.add_argument("--HR", action="store_true", help="large input size with scratches")
Global/test.py
@@ -1,142 +0,0 @@
# Copyright (c) Microsoft Corporation

import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torch
import cv2
import os

from .models.mapping_model import Pix2PixHDModel_Mapping
from .options.test_options import TestOptions

tensor2image = transforms.ToPILImage()


def data_transforms(img, method=Image.BILINEAR, scale=False):
    ow, oh = img.size
    pw, ph = ow, oh
    if scale:
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256

    h = int(round(oh / 4) * 4)
    w = int(round(ow / 4) * 4)

    if (h == ph) and (w == pw):
        return img

    return img.resize((w, h), method)
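# Illustrative sketch (editor's addition, not part of the diff): width and height
# are rounded to the nearest multiple of 4, presumably so the generator's
# down/up-sampling round trip preserves the spatial size. For example:
for size in (255, 509, 512):
    print(size, "->", int(round(size / 4) * 4))           # 255 -> 256, 509 -> 508, 512 -> 512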
def data_transforms_rgb_old(img):
    w, h = img.size
    A = img
    if w < 256 or h < 256:
        A = transforms.Scale(256, Image.BILINEAR)(img)
    return transforms.CenterCrop(256)(A)


def irregular_hole_synthesize(img, mask):
    img_np = np.array(img).astype("uint8")
    mask_np = np.array(mask).astype("uint8")
    mask_np = mask_np / 255
    img_new = img_np * (1 - mask_np) + mask_np * 255

    hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB")

    return hole_img


def parameter_set(opt, ckpt_dir):
    ## Default parameters
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.label_nc = 0
    opt.n_downsample_global = 3
    opt.mc = 64
    opt.k_size = 4
    opt.start_r = 1
    opt.mapping_n_block = 6
    opt.map_mc = 512
    opt.no_instance = True
    opt.checkpoints_dir = ckpt_dir

    if opt.Quality_restore:
        opt.name = "mapping_quality"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality")
    if opt.Scratch_and_Quality_restore:
        opt.NL_res = True
        opt.use_SN = True
        opt.correlation_renormalize = True
        opt.NL_use_mask = True
        opt.NL_fusion_method = "combine"
        opt.non_local = "Setting_42"
        opt.name = "mapping_scratch"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch")
    if opt.HR:
        opt.mapping_exp = 1
        opt.inference_optimize = True
        opt.mask_dilation = 3
        opt.name = "mapping_Patch_Attention"
def global_test(
    ckpt_dir: str, custom_args: list, input_image: Image, mask: Image = None
) -> Image:
    opt = TestOptions().parse(custom_args)

    parameter_set(opt, ckpt_dir)

    model = Pix2PixHDModel_Mapping()

    model.initialize(opt)
    model.eval()

    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()

    print("processing...")

    if opt.NL_use_mask:
        if opt.mask_dilation != 0:
            kernel = np.ones((3, 3), np.uint8)
            mask = np.array(mask)
            mask = cv2.dilate(mask, kernel, iterations=opt.mask_dilation)
            mask = Image.fromarray(mask.astype("uint8"))

        input_image = irregular_hole_synthesize(input_image, mask)
        mask = mask_transform(mask)
        mask = mask[:1, :, :]  # convert to single channel
        mask = mask.unsqueeze(0)
        input_image = img_transform(input_image)
        input_image = input_image.unsqueeze(0)

    else:
        if opt.test_mode == "Scale":
            input_image = data_transforms(input_image, scale=True)
        elif opt.test_mode == "Full":
            input_image = data_transforms(input_image, scale=False)
        elif opt.test_mode == "Crop":
            input_image = data_transforms_rgb_old(input_image)

        input_image = img_transform(input_image)
        input_image = input_image.unsqueeze(0)
        mask = torch.zeros_like(input_image)

    with torch.no_grad():
        generated = model.inference(input_image, mask)

    restored = torch.clamp((generated.data.cpu() + 1.0) / 2.0, 0.0, 1.0) * 255
    return tensor2image(restored[0].byte())
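# Illustrative sketch (editor's addition, not part of the diff): how the
# extension is expected to call global_test. The file name, checkpoint directory
# and flag list below are assumptions based on parameter_set() and the options
# defined above, not verified paths.
from PIL import Image

photo = Image.open("old_photo.png").convert("RGB")        # hypothetical input file
restored = global_test(
    ckpt_dir="lib_bopb2l/Global/checkpoints/restoration",  # assumed layout, see preload.py
    custom_args=["--test_mode", "Full", "--Quality_restore", "--gpu_ids", "-1"],
    input_image=photo,
)
restored.save("restored.png")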
@@ -1,36 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import random
import torch
from torch.autograd import Variable


class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images.data:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images
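# Illustrative sketch (editor's addition, not part of the diff): the pool
# stabilizes discriminator training by sometimes replaying previously generated
# fakes (the history trick of Shrivastava et al., 2017). A toy run with 1x1 "images":
import torch

pool = ImagePool(pool_size=2)
fakes = torch.arange(4.0).view(4, 1, 1, 1)                # pretend batch of 4 fake images
mixed = pool.query(torch.autograd.Variable(fakes))
# the first pool_size fakes are stored; each later one is returned as-is or
# swapped with a stored fake, with probability 0.5 each
print(mixed.shape)                                        # torch.Size([4, 1, 1, 1])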
@@ -1,58 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import torch.nn as nn


# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    if isinstance(image_tensor, list):
        image_numpy = []
        for i in range(len(image_tensor)):
            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
        return image_numpy
    image_numpy = image_tensor.cpu().float().numpy()
    if normalize:
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)


# Converts a one-hot tensor into a colorful label map
# (Colorize appears to come from the upstream pix2pixHD utilities; it is not defined in this trimmed module)
def tensor2label(label_tensor, n_label, imtype=np.uint8):
    if n_label == 0:
        return tensor2im(label_tensor, imtype)
    label_tensor = label_tensor.cpu().float()
    if label_tensor.size()[0] > 1:
        label_tensor = label_tensor.max(0, keepdim=True)[1]
    label_tensor = Colorize(n_label)(label_tensor)
    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
    return label_numpy.astype(imtype)


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)
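# Illustrative sketch (editor's addition, not part of the diff): tensor2im maps a
# normalized CHW tensor in [-1, 1] back to an HWC uint8 array in [0, 255]:
import numpy as np
import torch

t = torch.zeros(3, 2, 2)                                  # mid-gray in [-1, 1] space
arr = tensor2im(t)
assert arr.shape == (2, 2, 3) and arr.dtype == np.uint8
assert int(arr[0, 0, 0]) == 127                           # (0 + 1) / 2 * 255 = 127.5 -> 127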
@@ -1,21 +0,0 @@
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -14,11 +14,11 @@ This is an Extension for the [Automatic1111 Webui](https://github.com/AUTOMATIC1
 ## Requirements
 0. Install this Extension
 1. Download `global_checkpoints.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases)
-2. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/Global`
+2. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Global`
 3. Download `face_checkpoints.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases)
-4. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/Face_Enhancement`
+4. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Face_Enhancement`
 5. Download `shape_predictor_68_face_landmarks.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases)
-6. Extract the `.dat` **file** into `~webui/extensions/sd-webui-old-photo-restoration/Face_Detection`
+6. Extract the `.dat` **file** into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Face_Detection`
 
 > The [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases) page includes the original links, as well as the backups mirrored by myself
 
@@ -40,3 +40,4 @@ After installing this Extension, there will be an **Old Photo Restoration** section
   - *(This is **different** from the Webui built-in ones)*
 - **High Resolution:** Use higher parameters to do the processing
   - *(Only has an effect when either `Process Scratch` or `Face Restore` is also enabled)*
+- **Use CPU:** Enable this if you do not have a Nvidia GPU or are getting an **Out of Memory** error
Submodule lib_bopb2l added at b5b24d5475
preload.py
@@ -2,14 +2,28 @@ import os
 EXTENSION_FOLDER = os.path.dirname(os.path.realpath(__file__))
 
-FACE = os.path.join(EXTENSION_FOLDER, "Face_Enhancement", "checkpoints")
+FACE = os.path.join(
+    EXTENSION_FOLDER,
+    "lib_bopb2l",
+    "Face_Enhancement",
+    "checkpoints",
+)
+GLOBAL = os.path.join(
+    EXTENSION_FOLDER,
+    "lib_bopb2l",
+    "Global",
+    "checkpoints",
+)
+MDL = os.path.join(
+    EXTENSION_FOLDER,
+    "lib_bopb2l",
+    "Face_Detection",
+    "shape_predictor_68_face_landmarks.dat",
+)
+
 if not os.path.exists(FACE):
     print("\n[Warning] face_checkpoints not detected! Please download it from Release!")
 
-GLOBAL = os.path.join(EXTENSION_FOLDER, "Global", "checkpoints")
 if not os.path.exists(GLOBAL):
     print("[Warning] global_checkpoints not detected! Please download it from Release!")
 
-MDL = os.path.join(EXTENSION_FOLDER, "Face_Detection", "shape_predictor_68_face_landmarks.dat")
 if not os.path.exists(MDL):
     print("[Warning] face_landmarks not detected! Please download it from Release!\n")
@@ -1,10 +1,10 @@
-scikit-image
-easydict
-PyYAML
-dominate
-dill
-tensorboardX
-scipy
-opencv-python
+dominate
+easydict
+einops
+matplotlib
+opencv-python
+scikit-image
+scipy
+tensorboardX
@@ -5,12 +5,12 @@ from modules.api import api
 from fastapi import FastAPI, Body
 import gradio as gr
 
-from scripts.main_function import main
+from scripts.bopb2l_main import main
 
 
 def bop_api(_: gr.Blocks, app: FastAPI):
-    @app.post("/bop")
+    @app.post("/bopb2l/restore")
     async def old_photo_restoration(
         image: str = Body("", title="input image"),
         scratch: bool = Body(False, title="process scratch"),
@@ -22,9 +22,7 @@ def bop_api(_: gr.Blocks, app: FastAPI):
             return
 
         input_image = api.decode_base64_to_image(image)
-
         img = main(input_image, scratch, hr, face_res, cpu)
-
         return {"image": api.encode_pil_to_base64(img).decode("utf-8")}
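# Illustrative sketch (editor's addition, not part of the diff): calling the
# renamed endpoint from a client. The field names scratch/hr/face_res/cpu follow
# main()'s signature; the ones not shown in this hunk are assumptions.
import base64
import requests

with open("old_photo.png", "rb") as f:                    # hypothetical input file
    payload = {
        "image": base64.b64encode(f.read()).decode("utf-8"),
        "scratch": True,
    }

r = requests.post("http://127.0.0.1:7860/bopb2l/restore", json=payload)
restored_b64 = r.json()["image"]                          # base64-encoded image of the result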
@@ -1,7 +1,8 @@
-from modules import scripts_postprocessing, ui_components
+from modules.ui_components import InputAccordion
+from modules import scripts_postprocessing
 import gradio as gr
 
-from scripts.main_function import main
+from scripts.bopb2l_main import main
 
 
 class OldPhotoRestoration(scripts_postprocessing.ScriptPostprocessing):
@@ -9,22 +10,18 @@ class OldPhotoRestoration(scripts_postprocessing.ScriptPostprocessing):
     order = 200409484
 
     def ui(self):
-        with ui_components.InputAccordion(
-            False, label="Old Photo Restoration"
-        ) as enable:
-
+        with InputAccordion(False, label="Old Photo Restoration") as enable:
             proc_order = gr.Radio(
-                ["Restoration First", "Upscale First"],
-                label="Processing Order",
+                choices=("Restoration First", "Upscale First"),
+                value="Restoration First",
+                label="Processing Order",
             )
 
             with gr.Row():
-                do_scratch = gr.Checkbox(label="Process Scratch")
-                do_face_res = gr.Checkbox(label="Face Restore")
-
+                do_scratch = gr.Checkbox(False, label="Process Scratch")
+                do_face_res = gr.Checkbox(False, label="Face Restore")
             with gr.Row():
-                is_hr = gr.Checkbox(label="High Resolution")
+                is_hr = gr.Checkbox(False, label="High Resolution")
                 use_cpu = gr.Checkbox(True, label="Use CPU")
 
         args = {
@@ -1,12 +1,12 @@
-from Global.test import global_test
-from Global.detection import global_detection
+from lib_bopb2l.Global.test import global_test
+from lib_bopb2l.Global.detection import global_detection
 
-from Face_Detection.detect_all_dlib import detect
-from Face_Detection.detect_all_dlib_HR import detect_hr
-from Face_Detection.align_warp_back_multiple_dlib import align_warp
-from Face_Detection.align_warp_back_multiple_dlib_HR import align_warp_hr
+from lib_bopb2l.Face_Detection.detect_all_dlib import detect
+from lib_bopb2l.Face_Detection.detect_all_dlib_HR import detect_hr
+from lib_bopb2l.Face_Detection.align_warp_back_multiple_dlib import align_warp
+from lib_bopb2l.Face_Detection.align_warp_back_multiple_dlib_HR import align_warp_hr
 
-from Face_Enhancement.test_face import test_face
+from lib_bopb2l.Face_Enhancement.test_face import test_face
 
 from modules import scripts
 from PIL import Image
@@ -15,11 +15,12 @@ import os
 
 GLOBAL_CHECKPOINTS_FOLDER = os.path.join(
-    scripts.basedir(), "Global", "checkpoints", "restoration"
+    scripts.basedir(), "lib_bopb2l", "Global", "checkpoints", "restoration"
 )
 FACE_CHECKPOINTS_FOLDER = os.path.join(
-    scripts.basedir(), "Face_Enhancement", "checkpoints"
+    scripts.basedir(), "lib_bopb2l", "Face_Enhancement", "checkpoints"
 )
 
 FACE_ENHANCEMENT_CHECKPOINTS = ("Setting_9_epoch_100", "FaceSR_512")