From b8c6cf1a959adafb6e5d8d3972ec797fb1802b07 Mon Sep 17 00:00:00 2001
From: Haoming
Date: Mon, 4 Nov 2024 15:53:34 +0800
Subject: [PATCH] submodule

---
 .gitignore | 7 -
 .gitmodules | 3 +
 .../align_warp_back_multiple_dlib.py | 430 ---------
 .../align_warp_back_multiple_dlib_HR.py | 430 ---------
 Face_Detection/detect_all_dlib.py | 153 ---
 Face_Detection/detect_all_dlib_HR.py | 153 ---
 Face_Enhancement/data/__init__.py | 24 -
 Face_Enhancement/data/base_dataset.py | 125 ---
 Face_Enhancement/data/custom_dataset.py | 56 --
 Face_Enhancement/data/face_dataset.py | 72 --
 Face_Enhancement/data/image_folder.py | 101 --
 Face_Enhancement/data/pix2pix_dataset.py | 108 ---
 Face_Enhancement/models/__init__.py | 45 -
 Face_Enhancement/models/networks/__init__.py | 57 --
 .../models/networks/architecture.py | 173 ----
 .../models/networks/base_network.py | 58 --
 Face_Enhancement/models/networks/encoder.py | 53 --
 Face_Enhancement/models/networks/generator.py | 233 -----
 .../models/networks/normalization.py | 100 --
 .../networks/sync_batchnorm/__init__.py | 14 -
 .../networks/sync_batchnorm/batchnorm.py | 412 ---------
 .../sync_batchnorm/batchnorm_reimpl.py | 74 --
 .../models/networks/sync_batchnorm/comm.py | 137 ---
 .../networks/sync_batchnorm/replicate.py | 94 --
 .../networks/sync_batchnorm/unittest.py | 29 -
 Face_Enhancement/models/pix2pix_model.py | 246 -----
 Face_Enhancement/options/__init__.py | 2 -
 Face_Enhancement/options/base_options.py | 347 -------
 Face_Enhancement/options/test_options.py | 26 -
 Face_Enhancement/test_face.py | 34 -
 Face_Enhancement/util/__init__.py | 2 -
 Face_Enhancement/util/iter_counter.py | 74 --
 Face_Enhancement/util/util.py | 217 -----
 Face_Enhancement/util/visualizer.py | 134 ---
 Global/data/Create_Bigfile.py | 63 --
 Global/data/Load_Bigfile.py | 42 -
 Global/data/__init__.py | 0
 Global/data/base_data_loader.py | 16 -
 Global/data/base_dataset.py | 114 ---
 Global/data/custom_dataset_data_loader.py | 41 -
 Global/data/data_loader.py | 9 -
 Global/data/image_folder.py | 62 --
 Global/detection.py | 144 ---
 Global/detection_models/__init__.py | 0
 Global/detection_models/antialiasing.py | 70 --
 Global/detection_models/networks.py | 332 -------
 .../sync_batchnorm/__init__.py | 14 -
 .../sync_batchnorm/batchnorm.py | 412 ---------
 .../sync_batchnorm/batchnorm_reimpl.py | 74 --
 .../detection_models/sync_batchnorm/comm.py | 137 ---
 .../sync_batchnorm/replicate.py | 94 --
 .../sync_batchnorm/unittest.py | 29 -
 Global/detection_util/util.py | 245 -----
 .../models/NonLocal_feature_mapping_model.py | 204 ----
 Global/models/__init__.py | 0
 Global/models/base_model.py | 122 ---
 Global/models/mapping_model.py | 453 ---------
 Global/models/models.py | 39 -
 Global/models/networks.py | 875 ------------------
 Global/models/pix2pixHD_model.py | 333 -------
 Global/models/pix2pixHD_model_DA.py | 372 --------
 Global/options/__init__.py | 0
 Global/options/base_options.py | 469 ----------
 Global/options/test_options.py | 100 --
 Global/test.py | 142 ---
 Global/util/__init__.py | 0
 Global/util/image_pool.py | 36 -
 Global/util/util.py | 58 --
 LICENSE_MICROSOFT | 21 -
 README.md | 7 +-
 lib_bopb2l | 1 +
 preload.py | 24 +-
 requirements.txt | 12 +-
 scripts/bop_api.py | 6 +-
 scripts/{bop.py => bopb2l.py} | 21 +-
 scripts/{main_function.py => bopb2l_main.py} | 19 +-
 76 files changed, 54 insertions(+), 9381 deletions(-)
 create mode 100644 .gitmodules
 delete mode 100644 Face_Detection/align_warp_back_multiple_dlib.py
 delete mode 100644 Face_Detection/align_warp_back_multiple_dlib_HR.py
 delete mode 100644 Face_Detection/detect_all_dlib.py
 delete mode 100644 Face_Detection/detect_all_dlib_HR.py
 delete mode 100644 Face_Enhancement/data/__init__.py
 delete mode 100644 Face_Enhancement/data/base_dataset.py
 delete mode 100644 Face_Enhancement/data/custom_dataset.py
 delete mode 100644 Face_Enhancement/data/face_dataset.py
 delete mode 100644 Face_Enhancement/data/image_folder.py
 delete mode 100644 Face_Enhancement/data/pix2pix_dataset.py
 delete mode 100644 Face_Enhancement/models/__init__.py
 delete mode 100644 Face_Enhancement/models/networks/__init__.py
 delete mode 100644 Face_Enhancement/models/networks/architecture.py
 delete mode 100644 Face_Enhancement/models/networks/base_network.py
 delete mode 100644 Face_Enhancement/models/networks/encoder.py
 delete mode 100644 Face_Enhancement/models/networks/generator.py
 delete mode 100644 Face_Enhancement/models/networks/normalization.py
 delete mode 100644 Face_Enhancement/models/networks/sync_batchnorm/__init__.py
 delete mode 100644 Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py
 delete mode 100644 Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py
 delete mode 100644 Face_Enhancement/models/networks/sync_batchnorm/comm.py
 delete mode 100644 Face_Enhancement/models/networks/sync_batchnorm/replicate.py
 delete mode 100644 Face_Enhancement/models/networks/sync_batchnorm/unittest.py
 delete mode 100644 Face_Enhancement/models/pix2pix_model.py
 delete mode 100644 Face_Enhancement/options/__init__.py
 delete mode 100644 Face_Enhancement/options/base_options.py
 delete mode 100644 Face_Enhancement/options/test_options.py
 delete mode 100644 Face_Enhancement/test_face.py
 delete mode 100644 Face_Enhancement/util/__init__.py
 delete mode 100644 Face_Enhancement/util/iter_counter.py
 delete mode 100644 Face_Enhancement/util/util.py
 delete mode 100644 Face_Enhancement/util/visualizer.py
 delete mode 100644 Global/data/Create_Bigfile.py
 delete mode 100644 Global/data/Load_Bigfile.py
 delete mode 100644 Global/data/__init__.py
 delete mode 100644 Global/data/base_data_loader.py
 delete mode 100644 Global/data/base_dataset.py
 delete mode 100644 Global/data/custom_dataset_data_loader.py
 delete mode 100644 Global/data/data_loader.py
 delete mode 100644 Global/data/image_folder.py
 delete mode 100644 Global/detection.py
 delete mode 100644 Global/detection_models/__init__.py
 delete mode 100644 Global/detection_models/antialiasing.py
 delete mode 100644 Global/detection_models/networks.py
 delete mode 100644 Global/detection_models/sync_batchnorm/__init__.py
 delete mode 100644 Global/detection_models/sync_batchnorm/batchnorm.py
 delete mode 100644 Global/detection_models/sync_batchnorm/batchnorm_reimpl.py
 delete mode 100644 Global/detection_models/sync_batchnorm/comm.py
 delete mode 100644 Global/detection_models/sync_batchnorm/replicate.py
 delete mode 100644 Global/detection_models/sync_batchnorm/unittest.py
 delete mode 100644 Global/detection_util/util.py
 delete mode 100644 Global/models/NonLocal_feature_mapping_model.py
 delete mode 100644 Global/models/__init__.py
 delete mode 100644 Global/models/base_model.py
 delete mode 100644 Global/models/mapping_model.py
 delete mode 100644 Global/models/models.py
 delete mode 100644 Global/models/networks.py
 delete mode 100644 Global/models/pix2pixHD_model.py
 delete mode 100644 Global/models/pix2pixHD_model_DA.py
 delete mode 100644 Global/options/__init__.py
 delete mode 100644 Global/options/base_options.py
 delete mode 100644 Global/options/test_options.py
 delete mode 100644 Global/test.py
 delete mode 100644 Global/util/__init__.py
 delete mode 100644 Global/util/image_pool.py
 delete mode 100644 Global/util/util.py
 delete mode 100644 LICENSE_MICROSOFT
 create mode 160000 lib_bopb2l
 rename scripts/{bop.py => bopb2l.py} (76%)
 rename scripts/{main_function.py => bopb2l_main.py} (83%)

diff --git a/.gitignore b/.gitignore
index 5ab7688..bee8a64 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1 @@
-# Junks
 __pycache__
-*.pyc
-*~
-
-# Models
-*landmarks.dat
-*/checkpoints/*
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..1ea847d
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "lib_bopb2l"]
+	path = lib_bopb2l
+	url = https://github.com/Haoming02/BOP-B2L-Backend
diff --git a/Face_Detection/align_warp_back_multiple_dlib.py b/Face_Detection/align_warp_back_multiple_dlib.py
deleted file mode 100644
index f7b5fc1..0000000
--- a/Face_Detection/align_warp_back_multiple_dlib.py
+++ /dev/null
@@ -1,430 +0,0 @@
-# Copyright (c) Microsoft Corporation
-
-from skimage.transform import SimilarityTransform
-from matplotlib.patches import Rectangle
-from PIL import Image, ImageFilter
-from skimage.transform import warp
-from skimage import img_as_ubyte
-import matplotlib.pyplot as plt
-import numpy as np
-import dlib
-import cv2
-import os
-
-
-def calculate_cdf(histogram):
-    """
-    This method calculates the cumulative distribution function
-    :param array histogram: The values of the histogram
-    :return: normalized_cdf: The normalized cumulative distribution function
-    :rtype: array
-    """
-    # Get the cumulative sum of the elements
-    cdf = histogram.cumsum()
-
-    # Normalize the cdf
-    normalized_cdf = cdf / float(cdf.max())
-
-    return normalized_cdf
-
-
-def calculate_lookup(src_cdf, ref_cdf):
-    """
-    This method creates the lookup table
-    :param array src_cdf: The cdf for the source image
-    :param array ref_cdf: The cdf for the reference image
-    :return: lookup_table: The lookup table
-    :rtype: array
-    """
-    lookup_table = np.zeros(256)
-    lookup_val = 0
-    for src_pixel_val in range(len(src_cdf)):
-        lookup_val
-        for ref_pixel_val in range(len(ref_cdf)):
-            if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
-                lookup_val = ref_pixel_val
-                break
-        lookup_table[src_pixel_val] = lookup_val
-    return lookup_table
-
-
-def match_histograms(src_image, ref_image):
-    """
-    This method matches the source image histogram to the
-    reference signal
-    :param image src_image: The original source image
-    :param image ref_image: The reference image
-    :return: image_after_matching
-    :rtype: image (array)
-    """
-    # Split the images into the different color channels
-    # b means blue, g means green and r means red
-    src_b, src_g, src_r = cv2.split(src_image)
-    ref_b, ref_g, ref_r = cv2.split(ref_image)
-
-    # Compute the b, g, and r histograms separately
-    # The flatten() Numpy method returns a copy of the array c
-    # collapsed into one dimension.
- src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) - src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) - src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) - ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) - ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) - ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) - - # Compute the normalized cdf for the source and reference image - src_cdf_blue = calculate_cdf(src_hist_blue) - src_cdf_green = calculate_cdf(src_hist_green) - src_cdf_red = calculate_cdf(src_hist_red) - ref_cdf_blue = calculate_cdf(ref_hist_blue) - ref_cdf_green = calculate_cdf(ref_hist_green) - ref_cdf_red = calculate_cdf(ref_hist_red) - - # Make a separate lookup table for each color - blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) - green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) - red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) - - # Use the lookup function to transform the colors of the original - # source image - blue_after_transform = cv2.LUT(src_b, blue_lookup_table) - green_after_transform = cv2.LUT(src_g, green_lookup_table) - red_after_transform = cv2.LUT(src_r, red_lookup_table) - - # Put the image back together - image_after_matching = cv2.merge( - [blue_after_transform, green_after_transform, red_after_transform] - ) - image_after_matching = cv2.convertScaleAbs(image_after_matching) - - return image_after_matching - - -def _standard_face_pts(): - pts = ( - np.array( - [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], - np.float32, - ) - / 256.0 - - 1.0 - ) - - return np.reshape(pts, (5, 2)) - - -def _origin_face_pts(): - pts = np.array( - [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], - np.float32, - ) - - return np.reshape(pts, (5, 2)) - - -def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): - - std_pts = _standard_face_pts() # [-1,1] - target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 - - # print(target_pts) - - h, w, c = img.shape - if normalize == True: - landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 - landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 - - # print(landmark) - - affine = SimilarityTransform() - - affine.estimate(target_pts, landmark) - - return affine - - -def compute_inverse_transformation_matrix( - img, landmark, normalize, target_face_scale=1.0 -): - - std_pts = _standard_face_pts() # [-1,1] - target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 - - # print(target_pts) - - h, w, c = img.shape - if normalize == True: - landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 - landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 - - # print(landmark) - - affine = SimilarityTransform() - - affine.estimate(landmark, target_pts) - - return affine - - -def show_detection(image, box, landmark): - plt.imshow(image) - print(box[2] - box[0]) - plt.gca().add_patch( - Rectangle( - (box[1], box[0]), - box[2] - box[0], - box[3] - box[1], - linewidth=1, - edgecolor="r", - facecolor="none", - ) - ) - plt.scatter(landmark[0][0], landmark[0][1]) - plt.scatter(landmark[1][0], landmark[1][1]) - plt.scatter(landmark[2][0], landmark[2][1]) - plt.scatter(landmark[3][0], landmark[3][1]) - plt.scatter(landmark[4][0], landmark[4][1]) - plt.show() - - -def affine2theta(affine, input_w, input_h, target_w, target_h): - # param = np.linalg.inv(affine) - param = affine - theta = np.zeros([2, 3]) - theta[0, 0] = param[0, 0] * input_h / target_h - 
theta[0, 1] = param[0, 1] * input_w / target_h - theta[0, 2] = ( - 2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w - ) / target_h - 1 - theta[1, 0] = param[1, 0] * input_h / target_w - theta[1, 1] = param[1, 1] * input_w / target_w - theta[1, 2] = ( - 2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w - ) / target_w - 1 - return theta - - -def blur_blending(im1, im2, mask): - - mask = mask * 255.0 - - kernel = np.ones((10, 10), np.uint8) - mask = cv2.erode(mask, kernel, iterations=1) - - mask = Image.fromarray(mask.astype("uint8")).convert("L") - im1 = Image.fromarray(im1.astype("uint8")) - im2 = Image.fromarray(im2.astype("uint8")) - - mask_blur = mask.filter(ImageFilter.GaussianBlur(20)) - im = Image.composite(im1, im2, mask) - - im = Image.composite(im, im2, mask_blur) - - return np.array(im) / 255.0 - - -def blur_blending_cv2(im1, im2, mask): - - mask = mask * 255.0 - - kernel = np.ones((9, 9), np.uint8) - mask = cv2.erode(mask, kernel, iterations=3) - - mask_blur = cv2.GaussianBlur(mask, (25, 25), 0) - mask_blur /= 255.0 - - im = im1 * mask_blur + (1 - mask_blur) * im2 - - im /= 255.0 - im = np.clip(im, 0.0, 1.0) - - return im - - -def Poisson_blending(im1, im2, mask): - - # mask = 1 - mask - mask = mask * 255.0 - kernel = np.ones((10, 10), np.uint8) - mask = cv2.erode(mask, kernel, iterations=1) - mask /= 255 - mask = 1 - mask - mask *= 255 - - mask = mask[:, :, 0] - width, height, channels = im1.shape - center = (int(height / 2), int(width / 2)) - result = cv2.seamlessClone( - im2.astype("uint8"), - im1.astype("uint8"), - mask.astype("uint8"), - center, - cv2.MIXED_CLONE, - ) - - return result / 255.0 - - -def Poisson_B(im1, im2, mask, center): - - mask = mask * 255.0 - - result = cv2.seamlessClone( - im2.astype("uint8"), - im1.astype("uint8"), - mask.astype("uint8"), - center, - cv2.NORMAL_CLONE, - ) - - return result / 255 - - -def seamless_clone(old_face, new_face, raw_mask): - - height, width, _ = old_face.shape - height = height // 2 - width = width // 2 - - y_indices, x_indices, _ = np.nonzero(raw_mask) - y_crop = slice(np.min(y_indices), np.max(y_indices)) - x_crop = slice(np.min(x_indices), np.max(x_indices)) - y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height)) - x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width)) - - insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8") - insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8") - insertion_mask[insertion_mask != 0] = 255 - prior = np.rint( - np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant") - ).astype("uint8") - # if np.sum(insertion_mask) == 0: - n_mask = insertion_mask[1:-1, 1:-1, :] - n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) - print(n_mask.shape) - x, y, w, h = cv2.boundingRect(n_mask[:, :, 0]) - if w < 4 or h < 4: - blended = prior - else: - blended = cv2.seamlessClone( - insertion, # pylint: disable=no-member - prior, - insertion_mask, - (x_center, y_center), - cv2.NORMAL_CLONE, - ) # pylint: disable=no-member - - blended = blended[height:-height, width:-width] - - return blended.astype("float32") / 255.0 - - -def get_landmark(face_landmarks, id): - part = face_landmarks.part(id) - x = part.x - y = part.y - - return (x, y) - - -def search(face_landmarks): - - x1, y1 = get_landmark(face_landmarks, 36) - x2, y2 = get_landmark(face_landmarks, 39) - x3, y3 = get_landmark(face_landmarks, 42) - x4, y4 = get_landmark(face_landmarks, 45) - - x_nose, 
y_nose = get_landmark(face_landmarks, 30) - - x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) - x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) - - x_left_eye = int((x1 + x2) / 2) - y_left_eye = int((y1 + y2) / 2) - x_right_eye = int((x3 + x4) / 2) - y_right_eye = int((y3 + y4) / 2) - - results = np.array( - [ - [x_left_eye, y_left_eye], - [x_right_eye, y_right_eye], - [x_nose, y_nose], - [x_left_mouth, y_left_mouth], - [x_right_mouth, y_right_mouth], - ] - ) - - return results - - -def align_warp( - original_image: Image, - restored_faces: list, -) -> Image: - - if len(restored_faces) == 0: - return original_image - - face_detector = dlib.get_frontal_face_detector() - landmark = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "shape_predictor_68_face_landmarks.dat", - ) - landmark_locator = dlib.shape_predictor(landmark) - - origin_width, origin_height = original_image.size - image = np.array(original_image) - - faces = face_detector(image) - - blended = image - for face_id in range(len(faces)): - - current_face = faces[face_id] - face_landmarks = landmark_locator(image, current_face) - current_fl = search(face_landmarks) - - forward_mask = np.ones_like(image).astype("uint8") - affine = compute_transformation_matrix( - image, current_fl, False, target_face_scale=1.3 - ) - aligned_face = warp( - image, affine, output_shape=(256, 256, 3), preserve_range=True - ) - forward_mask = warp( - forward_mask, - affine, - output_shape=(256, 256, 3), - order=0, - preserve_range=True, - ) - - affine_inverse = affine.inverse - cur_face = np.array(restored_faces[face_id]) - - ## Histogram Color matching - A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) - B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) - B = match_histograms(B, A) - cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) - - warped_back = warp( - cur_face, - affine_inverse, - output_shape=(origin_height, origin_width, 3), - order=3, - preserve_range=True, - ) - - backward_mask = warp( - forward_mask, - affine_inverse, - output_shape=(origin_height, origin_width, 3), - order=0, - preserve_range=True, - ) ## Nearest neighbour - - blended = blur_blending_cv2(warped_back, blended, backward_mask) - blended *= 255.0 - - return Image.fromarray(img_as_ubyte(blended / 255.0)) diff --git a/Face_Detection/align_warp_back_multiple_dlib_HR.py b/Face_Detection/align_warp_back_multiple_dlib_HR.py deleted file mode 100644 index 8b17721..0000000 --- a/Face_Detection/align_warp_back_multiple_dlib_HR.py +++ /dev/null @@ -1,430 +0,0 @@ -# Copyright (c) Microsoft Corporation - -from skimage.transform import SimilarityTransform -from matplotlib.patches import Rectangle -from PIL import Image, ImageFilter -from skimage.transform import warp -from skimage import img_as_ubyte -import matplotlib.pyplot as plt -import numpy as np -import dlib -import cv2 -import os - - -def calculate_cdf(histogram): - """ - This method calculates the cumulative distribution function - :param array histogram: The values of the histogram - :return: normalized_cdf: The normalized cumulative distribution function - :rtype: array - """ - # Get the cumulative sum of the elements - cdf = histogram.cumsum() - - # Normalize the cdf - normalized_cdf = cdf / float(cdf.max()) - - return normalized_cdf - - -def calculate_lookup(src_cdf, ref_cdf): - """ - This method creates the lookup table - :param array src_cdf: The cdf for the source image - :param array ref_cdf: The cdf for the reference image - :return: 
lookup_table: The lookup table - :rtype: array - """ - lookup_table = np.zeros(256) - lookup_val = 0 - for src_pixel_val in range(len(src_cdf)): - lookup_val - for ref_pixel_val in range(len(ref_cdf)): - if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: - lookup_val = ref_pixel_val - break - lookup_table[src_pixel_val] = lookup_val - return lookup_table - - -def match_histograms(src_image, ref_image): - """ - This method matches the source image histogram to the - reference signal - :param image src_image: The original source image - :param image ref_image: The reference image - :return: image_after_matching - :rtype: image (array) - """ - # Split the images into the different color channels - # b means blue, g means green and r means red - src_b, src_g, src_r = cv2.split(src_image) - ref_b, ref_g, ref_r = cv2.split(ref_image) - - # Compute the b, g, and r histograms separately - # The flatten() Numpy method returns a copy of the array c - # collapsed into one dimension. - src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) - src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) - src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) - ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) - ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) - ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) - - # Compute the normalized cdf for the source and reference image - src_cdf_blue = calculate_cdf(src_hist_blue) - src_cdf_green = calculate_cdf(src_hist_green) - src_cdf_red = calculate_cdf(src_hist_red) - ref_cdf_blue = calculate_cdf(ref_hist_blue) - ref_cdf_green = calculate_cdf(ref_hist_green) - ref_cdf_red = calculate_cdf(ref_hist_red) - - # Make a separate lookup table for each color - blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) - green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) - red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) - - # Use the lookup function to transform the colors of the original - # source image - blue_after_transform = cv2.LUT(src_b, blue_lookup_table) - green_after_transform = cv2.LUT(src_g, green_lookup_table) - red_after_transform = cv2.LUT(src_r, red_lookup_table) - - # Put the image back together - image_after_matching = cv2.merge( - [blue_after_transform, green_after_transform, red_after_transform] - ) - image_after_matching = cv2.convertScaleAbs(image_after_matching) - - return image_after_matching - - -def _standard_face_pts(): - pts = ( - np.array( - [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], - np.float32, - ) - / 256.0 - - 1.0 - ) - - return np.reshape(pts, (5, 2)) - - -def _origin_face_pts(): - pts = np.array( - [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], - np.float32, - ) - - return np.reshape(pts, (5, 2)) - - -def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): - - std_pts = _standard_face_pts() # [-1,1] - target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 - - # print(target_pts) - - h, w, c = img.shape - if normalize == True: - landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 - landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 - - # print(landmark) - - affine = SimilarityTransform() - - affine.estimate(target_pts, landmark) - - return affine - - -def compute_inverse_transformation_matrix( - img, landmark, normalize, target_face_scale=1.0 -): - - std_pts = _standard_face_pts() # [-1,1] - target_pts = (std_pts * target_face_scale 
+ 1) / 2 * 512.0 - - # print(target_pts) - - h, w, c = img.shape - if normalize == True: - landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 - landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 - - # print(landmark) - - affine = SimilarityTransform() - - affine.estimate(landmark, target_pts) - - return affine - - -def show_detection(image, box, landmark): - plt.imshow(image) - print(box[2] - box[0]) - plt.gca().add_patch( - Rectangle( - (box[1], box[0]), - box[2] - box[0], - box[3] - box[1], - linewidth=1, - edgecolor="r", - facecolor="none", - ) - ) - plt.scatter(landmark[0][0], landmark[0][1]) - plt.scatter(landmark[1][0], landmark[1][1]) - plt.scatter(landmark[2][0], landmark[2][1]) - plt.scatter(landmark[3][0], landmark[3][1]) - plt.scatter(landmark[4][0], landmark[4][1]) - plt.show() - - -def affine2theta(affine, input_w, input_h, target_w, target_h): - # param = np.linalg.inv(affine) - param = affine - theta = np.zeros([2, 3]) - theta[0, 0] = param[0, 0] * input_h / target_h - theta[0, 1] = param[0, 1] * input_w / target_h - theta[0, 2] = ( - 2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w - ) / target_h - 1 - theta[1, 0] = param[1, 0] * input_h / target_w - theta[1, 1] = param[1, 1] * input_w / target_w - theta[1, 2] = ( - 2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w - ) / target_w - 1 - return theta - - -def blur_blending(im1, im2, mask): - - mask = mask * 255.0 - - kernel = np.ones((10, 10), np.uint8) - mask = cv2.erode(mask, kernel, iterations=1) - - mask = Image.fromarray(mask.astype("uint8")).convert("L") - im1 = Image.fromarray(im1.astype("uint8")) - im2 = Image.fromarray(im2.astype("uint8")) - - mask_blur = mask.filter(ImageFilter.GaussianBlur(20)) - im = Image.composite(im1, im2, mask) - - im = Image.composite(im, im2, mask_blur) - - return np.array(im) / 255.0 - - -def blur_blending_cv2(im1, im2, mask): - - mask = mask * 255.0 - - kernel = np.ones((9, 9), np.uint8) - mask = cv2.erode(mask, kernel, iterations=3) - - mask_blur = cv2.GaussianBlur(mask, (25, 25), 0) - mask_blur /= 255.0 - - im = im1 * mask_blur + (1 - mask_blur) * im2 - - im /= 255.0 - im = np.clip(im, 0.0, 1.0) - - return im - - -def Poisson_blending(im1, im2, mask): - - # mask = 1 - mask - mask = mask * 255.0 - kernel = np.ones((10, 10), np.uint8) - mask = cv2.erode(mask, kernel, iterations=1) - mask /= 255 - mask = 1 - mask - mask *= 255 - - mask = mask[:, :, 0] - width, height, channels = im1.shape - center = (int(height / 2), int(width / 2)) - result = cv2.seamlessClone( - im2.astype("uint8"), - im1.astype("uint8"), - mask.astype("uint8"), - center, - cv2.MIXED_CLONE, - ) - - return result / 255.0 - - -def Poisson_B(im1, im2, mask, center): - - mask = mask * 255.0 - - result = cv2.seamlessClone( - im2.astype("uint8"), - im1.astype("uint8"), - mask.astype("uint8"), - center, - cv2.NORMAL_CLONE, - ) - - return result / 255 - - -def seamless_clone(old_face, new_face, raw_mask): - - height, width, _ = old_face.shape - height = height // 2 - width = width // 2 - - y_indices, x_indices, _ = np.nonzero(raw_mask) - y_crop = slice(np.min(y_indices), np.max(y_indices)) - x_crop = slice(np.min(x_indices), np.max(x_indices)) - y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height)) - x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width)) - - insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8") - insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8") - insertion_mask[insertion_mask != 0] = 255 - prior = np.rint( - 
np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant") - ).astype("uint8") - # if np.sum(insertion_mask) == 0: - n_mask = insertion_mask[1:-1, 1:-1, :] - n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) - print(n_mask.shape) - x, y, w, h = cv2.boundingRect(n_mask[:, :, 0]) - if w < 4 or h < 4: - blended = prior - else: - blended = cv2.seamlessClone( - insertion, # pylint: disable=no-member - prior, - insertion_mask, - (x_center, y_center), - cv2.NORMAL_CLONE, - ) # pylint: disable=no-member - - blended = blended[height:-height, width:-width] - - return blended.astype("float32") / 255.0 - - -def get_landmark(face_landmarks, id): - part = face_landmarks.part(id) - x = part.x - y = part.y - - return (x, y) - - -def search(face_landmarks): - - x1, y1 = get_landmark(face_landmarks, 36) - x2, y2 = get_landmark(face_landmarks, 39) - x3, y3 = get_landmark(face_landmarks, 42) - x4, y4 = get_landmark(face_landmarks, 45) - - x_nose, y_nose = get_landmark(face_landmarks, 30) - - x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) - x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) - - x_left_eye = int((x1 + x2) / 2) - y_left_eye = int((y1 + y2) / 2) - x_right_eye = int((x3 + x4) / 2) - y_right_eye = int((y3 + y4) / 2) - - results = np.array( - [ - [x_left_eye, y_left_eye], - [x_right_eye, y_right_eye], - [x_nose, y_nose], - [x_left_mouth, y_left_mouth], - [x_right_mouth, y_right_mouth], - ] - ) - - return results - - -def align_warp_hr( - original_image: Image, - restored_faces: list, -) -> Image: - - if len(restored_faces) == 0: - return original_image - - face_detector = dlib.get_frontal_face_detector() - landmark = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "shape_predictor_68_face_landmarks.dat", - ) - landmark_locator = dlib.shape_predictor(landmark) - - origin_width, origin_height = original_image.size - image = np.array(original_image) - - faces = face_detector(image) - - blended = image - for face_id in range(len(faces)): - - current_face = faces[face_id] - face_landmarks = landmark_locator(image, current_face) - current_fl = search(face_landmarks) - - forward_mask = np.ones_like(image).astype("uint8") - affine = compute_transformation_matrix( - image, current_fl, False, target_face_scale=1.3 - ) - aligned_face = warp( - image, affine, output_shape=(512, 512, 3), preserve_range=True - ) - forward_mask = warp( - forward_mask, - affine, - output_shape=(512, 512, 3), - order=0, - preserve_range=True, - ) - - affine_inverse = affine.inverse - cur_face = np.array(restored_faces[face_id]) - - ## Histogram Color matching - A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) - B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) - B = match_histograms(B, A) - cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) - - warped_back = warp( - cur_face, - affine_inverse, - output_shape=(origin_height, origin_width, 3), - order=3, - preserve_range=True, - ) - - backward_mask = warp( - forward_mask, - affine_inverse, - output_shape=(origin_height, origin_width, 3), - order=0, - preserve_range=True, - ) ## Nearest neighbour - - blended = blur_blending_cv2(warped_back, blended, backward_mask) - blended *= 255.0 - - return Image.fromarray(img_as_ubyte(blended / 255.0)) diff --git a/Face_Detection/detect_all_dlib.py b/Face_Detection/detect_all_dlib.py deleted file mode 100644 index 861b2e3..0000000 --- a/Face_Detection/detect_all_dlib.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) 
Microsoft Corporation - -from skimage.transform import SimilarityTransform -from matplotlib.patches import Rectangle -from skimage.transform import warp -from skimage import img_as_ubyte -import matplotlib.pyplot as plt -from PIL import Image -import numpy as np -import dlib -import os - - -def _standard_face_pts(): - pts = ( - np.array( - [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], - np.float32, - ) - / 256.0 - - 1.0 - ) - - return np.reshape(pts, (5, 2)) - - -def get_landmark(face_landmarks, id): - part = face_landmarks.part(id) - x = part.x - y = part.y - - return (x, y) - - -def search(face_landmarks): - - x1, y1 = get_landmark(face_landmarks, 36) - x2, y2 = get_landmark(face_landmarks, 39) - x3, y3 = get_landmark(face_landmarks, 42) - x4, y4 = get_landmark(face_landmarks, 45) - - x_nose, y_nose = get_landmark(face_landmarks, 30) - - x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) - x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) - - x_left_eye = int((x1 + x2) / 2) - y_left_eye = int((y1 + y2) / 2) - x_right_eye = int((x3 + x4) / 2) - y_right_eye = int((y3 + y4) / 2) - - results = np.array( - [ - [x_left_eye, y_left_eye], - [x_right_eye, y_right_eye], - [x_nose, y_nose], - [x_left_mouth, y_left_mouth], - [x_right_mouth, y_right_mouth], - ] - ) - - return results - - -def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): - - std_pts = _standard_face_pts() # [-1,1] - target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 - - # print(target_pts) - - h, w, c = img.shape - if normalize == True: - landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 - landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 - - # print(landmark) - - affine = SimilarityTransform() - - affine.estimate(target_pts, landmark) - - return affine.params - - -def show_detection(image, box, landmark): - plt.imshow(image) - print(box[2] - box[0]) - plt.gca().add_patch( - Rectangle( - (box[1], box[0]), - box[2] - box[0], - box[3] - box[1], - linewidth=1, - edgecolor="r", - facecolor="none", - ) - ) - plt.scatter(landmark[0][0], landmark[0][1]) - plt.scatter(landmark[1][0], landmark[1][1]) - plt.scatter(landmark[2][0], landmark[2][1]) - plt.scatter(landmark[3][0], landmark[3][1]) - plt.scatter(landmark[4][0], landmark[4][1]) - plt.show() - - -def affine2theta(affine, input_w, input_h, target_w, target_h): - # param = np.linalg.inv(affine) - param = affine - theta = np.zeros([2, 3]) - theta[0, 0] = param[0, 0] * input_h / target_h - theta[0, 1] = param[0, 1] * input_w / target_h - theta[0, 2] = ( - 2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w - ) / target_h - 1 - theta[1, 0] = param[1, 0] * input_h / target_w - theta[1, 1] = param[1, 1] * input_w / target_w - theta[1, 2] = ( - 2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w - ) / target_w - 1 - return theta - - -def detect(input_image: Image) -> list: - face_detector = dlib.get_frontal_face_detector() - detected_faces = [] - - landmark = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "shape_predictor_68_face_landmarks.dat", - ) - landmark_locator = dlib.shape_predictor(landmark) - - image = np.array(input_image) - faces = face_detector(image) - - if len(faces) == 0: - return detected_faces - - for face_id in range(len(faces)): - - current_face = faces[face_id] - face_landmarks = landmark_locator(image, current_face) - current_fl = search(face_landmarks) - - affine = compute_transformation_matrix( - image, current_fl, False, target_face_scale=1.3 - ) 
- - aligned_face = warp(image, affine, output_shape=(256, 256, 3)) - detected_faces.append(Image.fromarray(img_as_ubyte(aligned_face))) - - return detected_faces diff --git a/Face_Detection/detect_all_dlib_HR.py b/Face_Detection/detect_all_dlib_HR.py deleted file mode 100644 index 1c48b6e..0000000 --- a/Face_Detection/detect_all_dlib_HR.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Microsoft Corporation - -from skimage.transform import SimilarityTransform -from matplotlib.patches import Rectangle -from skimage.transform import warp -from skimage import img_as_ubyte -import matplotlib.pyplot as plt -from PIL import Image -import numpy as np -import dlib -import os - - -def _standard_face_pts(): - pts = ( - np.array( - [196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], - np.float32, - ) - / 256.0 - - 1.0 - ) - - return np.reshape(pts, (5, 2)) - - -def get_landmark(face_landmarks, id): - part = face_landmarks.part(id) - x = part.x - y = part.y - - return (x, y) - - -def search(face_landmarks): - - x1, y1 = get_landmark(face_landmarks, 36) - x2, y2 = get_landmark(face_landmarks, 39) - x3, y3 = get_landmark(face_landmarks, 42) - x4, y4 = get_landmark(face_landmarks, 45) - - x_nose, y_nose = get_landmark(face_landmarks, 30) - - x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) - x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) - - x_left_eye = int((x1 + x2) / 2) - y_left_eye = int((y1 + y2) / 2) - x_right_eye = int((x3 + x4) / 2) - y_right_eye = int((y3 + y4) / 2) - - results = np.array( - [ - [x_left_eye, y_left_eye], - [x_right_eye, y_right_eye], - [x_nose, y_nose], - [x_left_mouth, y_left_mouth], - [x_right_mouth, y_right_mouth], - ] - ) - - return results - - -def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): - - std_pts = _standard_face_pts() # [-1,1] - target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 - - # print(target_pts) - - h, w, c = img.shape - if normalize == True: - landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 - landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 - - # print(landmark) - - affine = SimilarityTransform() - - affine.estimate(target_pts, landmark) - - return affine.params - - -def show_detection(image, box, landmark): - plt.imshow(image) - print(box[2] - box[0]) - plt.gca().add_patch( - Rectangle( - (box[1], box[0]), - box[2] - box[0], - box[3] - box[1], - linewidth=1, - edgecolor="r", - facecolor="none", - ) - ) - plt.scatter(landmark[0][0], landmark[0][1]) - plt.scatter(landmark[1][0], landmark[1][1]) - plt.scatter(landmark[2][0], landmark[2][1]) - plt.scatter(landmark[3][0], landmark[3][1]) - plt.scatter(landmark[4][0], landmark[4][1]) - plt.show() - - -def affine2theta(affine, input_w, input_h, target_w, target_h): - # param = np.linalg.inv(affine) - param = affine - theta = np.zeros([2, 3]) - theta[0, 0] = param[0, 0] * input_h / target_h - theta[0, 1] = param[0, 1] * input_w / target_h - theta[0, 2] = ( - 2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w - ) / target_h - 1 - theta[1, 0] = param[1, 0] * input_h / target_w - theta[1, 1] = param[1, 1] * input_w / target_w - theta[1, 2] = ( - 2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w - ) / target_w - 1 - return theta - - -def detect_hr(input_image: Image) -> list: - face_detector = dlib.get_frontal_face_detector() - detected_faces = [] - - landmark = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "shape_predictor_68_face_landmarks.dat", - ) - landmark_locator = 
dlib.shape_predictor(landmark) - - image = np.array(input_image) - faces = face_detector(image) - - if len(faces) == 0: - return detected_faces - - for face_id in range(len(faces)): - - current_face = faces[face_id] - face_landmarks = landmark_locator(image, current_face) - current_fl = search(face_landmarks) - - affine = compute_transformation_matrix( - image, current_fl, False, target_face_scale=1.3 - ) - - aligned_face = warp(image, affine, output_shape=(512, 512, 3)) - detected_faces.append(Image.fromarray(img_as_ubyte(aligned_face))) - - return detected_faces diff --git a/Face_Enhancement/data/__init__.py b/Face_Enhancement/data/__init__.py deleted file mode 100644 index 4bcda98..0000000 --- a/Face_Enhancement/data/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import torch.utils.data -from .face_dataset import FaceTestDataset - - -def create_dataloader(opt, faces: list): - - instance = FaceTestDataset() - instance.initialize(opt, faces) - - print( - f"Dataset [{type(instance).__name__}] with {len(instance)} faces was created..." - ) - - dataloader = torch.utils.data.DataLoader( - instance, - batch_size=opt.batchSize, - shuffle=not opt.serial_batches, - num_workers=int(opt.nThreads), - drop_last=opt.isTrain, - ) - - return dataloader diff --git a/Face_Enhancement/data/base_dataset.py b/Face_Enhancement/data/base_dataset.py deleted file mode 100644 index 57595dd..0000000 --- a/Face_Enhancement/data/base_dataset.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch.utils.data as data -from PIL import Image -import torchvision.transforms as transforms -import numpy as np -import random - - -class BaseDataset(data.Dataset): - def __init__(self): - super(BaseDataset, self).__init__() - - @staticmethod - def modify_commandline_options(parser, is_train): - return parser - - def initialize(self, opt): - pass - - -def get_params(opt, size): - w, h = size - new_h = h - new_w = w - if opt.preprocess_mode == "resize_and_crop": - new_h = new_w = opt.load_size - elif opt.preprocess_mode == "scale_width_and_crop": - new_w = opt.load_size - new_h = opt.load_size * h // w - elif opt.preprocess_mode == "scale_shortside_and_crop": - ss, ls = min(w, h), max(w, h) # shortside and longside - width_is_shorter = w == ss - ls = int(opt.load_size * ls / ss) - new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss) - - x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) - y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) - - flip = random.random() > 0.5 - return {"crop_pos": (x, y), "flip": flip} - - -def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True): - transform_list = [] - if "resize" in opt.preprocess_mode: - osize = [opt.load_size, opt.load_size] - transform_list.append(transforms.Resize(osize, interpolation=method)) - elif "scale_width" in opt.preprocess_mode: - transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) - elif "scale_shortside" in opt.preprocess_mode: - transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) - - if "crop" in opt.preprocess_mode: - transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size))) - - if opt.preprocess_mode == "none": - base = 32 - transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) - - if opt.preprocess_mode == "fixed": - w = opt.crop_size - h = 
round(opt.crop_size / opt.aspect_ratio) - transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) - - if opt.isTrain and not opt.no_flip: - transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"]))) - - if toTensor: - transform_list += [transforms.ToTensor()] - - if normalize: - transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - return transforms.Compose(transform_list) - - -def normalize(): - return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - - -def __resize(img, w, h, method=Image.BICUBIC): - return img.resize((w, h), method) - - -def __make_power_2(img, base, method=Image.BICUBIC): - ow, oh = img.size - h = int(round(oh / base) * base) - w = int(round(ow / base) * base) - if (h == oh) and (w == ow): - return img - return img.resize((w, h), method) - - -def __scale_width(img, target_width, method=Image.BICUBIC): - ow, oh = img.size - if ow == target_width: - return img - w = target_width - h = int(target_width * oh / ow) - return img.resize((w, h), method) - - -def __scale_shortside(img, target_width, method=Image.BICUBIC): - ow, oh = img.size - ss, ls = min(ow, oh), max(ow, oh) # shortside and longside - width_is_shorter = ow == ss - if ss == target_width: - return img - ls = int(target_width * ls / ss) - nw, nh = (ss, ls) if width_is_shorter else (ls, ss) - return img.resize((nw, nh), method) - - -def __crop(img, pos, size): - ow, oh = img.size - x1, y1 = pos - tw = th = size - return img.crop((x1, y1, x1 + tw, y1 + th)) - - -def __flip(img, flip): - if flip: - return img.transpose(Image.FLIP_LEFT_RIGHT) - return img diff --git a/Face_Enhancement/data/custom_dataset.py b/Face_Enhancement/data/custom_dataset.py deleted file mode 100644 index 7aa7af7..0000000 --- a/Face_Enhancement/data/custom_dataset.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .pix2pix_dataset import Pix2pixDataset -from .image_folder import make_dataset - - -class CustomDataset(Pix2pixDataset): - """ Dataset that loads images from directories - Use option --label_dir, --image_dir, --instance_dir to specify the directories. - The images in the directories are sorted in alphabetical order and paired in order. - """ - - @staticmethod - def modify_commandline_options(parser, is_train): - parser = Pix2pixDataset.modify_commandline_options(parser, is_train) - parser.set_defaults(preprocess_mode="resize_and_crop") - load_size = 286 if is_train else 256 - parser.set_defaults(load_size=load_size) - parser.set_defaults(crop_size=256) - parser.set_defaults(display_winsize=256) - parser.set_defaults(label_nc=13) - parser.set_defaults(contain_dontcare_label=False) - - parser.add_argument( - "--label_dir", type=str, required=True, help="path to the directory that contains label images" - ) - parser.add_argument( - "--image_dir", type=str, required=True, help="path to the directory that contains photo images" - ) - parser.add_argument( - "--instance_dir", - type=str, - default="", - help="path to the directory that contains instance maps. 
Leave black if not exists", - ) - return parser - - def get_paths(self, opt): - label_dir = opt.label_dir - label_paths = make_dataset(label_dir, recursive=False, read_cache=True) - - image_dir = opt.image_dir - image_paths = make_dataset(image_dir, recursive=False, read_cache=True) - - if len(opt.instance_dir) > 0: - instance_dir = opt.instance_dir - instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) - else: - instance_paths = [] - - assert len(label_paths) == len( - image_paths - ), "The #images in %s and %s do not match. Is there something wrong?" - - return label_paths, image_paths, instance_paths diff --git a/Face_Enhancement/data/face_dataset.py b/Face_Enhancement/data/face_dataset.py deleted file mode 100644 index 58ed1c4..0000000 --- a/Face_Enhancement/data/face_dataset.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Microsoft Corporation - -from .base_dataset import BaseDataset, get_params, get_transform -import torch - - -class FaceTestDataset(BaseDataset): - @staticmethod - def modify_commandline_options(parser, is_train): - parser.add_argument( - "--no_pairing_check", - action="store_true", - help="If specified, skip sanity check of correct label-image file pairing", - ) - # parser.set_defaults(contain_dontcare_label=False) - # parser.set_defaults(no_instance=True) - return parser - - def initialize(self, opt, faces: list): - self.opt = opt - self.images = faces # All the images - - self.parts = [ - "skin", - "hair", - "l_brow", - "r_brow", - "l_eye", - "r_eye", - "eye_g", - "l_ear", - "r_ear", - "ear_r", - "nose", - "mouth", - "u_lip", - "l_lip", - "neck", - "neck_l", - "cloth", - "hat", - ] - - size = len(self.images) - self.dataset_size = size - - def __getitem__(self, index): - params = get_params(self.opt, (-1, -1)) - - image = self.images[index].convert("RGB") - - transform_image = get_transform(self.opt, params) - image_tensor = transform_image(image) - - full_label = [] - - for _ in self.parts: - current_part = torch.zeros((self.opt.load_size, self.opt.load_size)) - full_label.append(current_part) - - full_label_tensor = torch.stack(full_label, 0) - - input_dict = { - "label": full_label_tensor, - "image": image_tensor, - "path": "", - } - - return input_dict - - def __len__(self): - return self.dataset_size diff --git a/Face_Enhancement/data/image_folder.py b/Face_Enhancement/data/image_folder.py deleted file mode 100644 index 7281eb2..0000000 --- a/Face_Enhancement/data/image_folder.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import torch.utils.data as data -from PIL import Image -import os - -IMG_EXTENSIONS = [ - ".jpg", - ".JPG", - ".jpeg", - ".JPEG", - ".png", - ".PNG", - ".ppm", - ".PPM", - ".bmp", - ".BMP", - ".tiff", - ".webp", -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset_rec(dir, images): - assert os.path.isdir(dir), "%s is not a valid directory" % dir - - for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - - -def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): - images = [] - - if read_cache: - possible_filelist = os.path.join(dir, "files.list") - if os.path.isfile(possible_filelist): - with open(possible_filelist, "r") as f: - images = f.read().splitlines() - return images - - if recursive: - make_dataset_rec(dir, images) - else: - assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir - - for root, dnames, fnames in sorted(os.walk(dir)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - - if write_cache: - filelist_cache = os.path.join(dir, "files.list") - with open(filelist_cache, "w") as f: - for path in images: - f.write("%s\n" % path) - print("wrote filelist cache at %s" % filelist_cache) - - return images - - -def default_loader(path): - return Image.open(path).convert("RGB") - - -class ImageFolder(data.Dataset): - def __init__(self, root, transform=None, return_paths=False, loader=default_loader): - imgs = make_dataset(root) - if len(imgs) == 0: - raise ( - RuntimeError( - "Found 0 images in: " + root + "\n" - "Supported image extensions are: " + ",".join(IMG_EXTENSIONS) - ) - ) - - self.root = root - self.imgs = imgs - self.transform = transform - self.return_paths = return_paths - self.loader = loader - - def __getitem__(self, index): - path = self.imgs[index] - img = self.loader(path) - if self.transform is not None: - img = self.transform(img) - if self.return_paths: - return img, path - else: - return img - - def __len__(self): - return len(self.imgs) diff --git a/Face_Enhancement/data/pix2pix_dataset.py b/Face_Enhancement/data/pix2pix_dataset.py deleted file mode 100644 index f51189c..0000000 --- a/Face_Enhancement/data/pix2pix_dataset.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from .base_dataset import BaseDataset, get_params, get_transform -from PIL import Image -from ..util import util -import os - - -class Pix2pixDataset(BaseDataset): - @staticmethod - def modify_commandline_options(parser, is_train): - parser.add_argument( - "--no_pairing_check", - action="store_true", - help="If specified, skip sanity check of correct label-image file pairing", - ) - return parser - - def initialize(self, opt): - self.opt = opt - - label_paths, image_paths, instance_paths = self.get_paths(opt) - - util.natural_sort(label_paths) - util.natural_sort(image_paths) - if not opt.no_instance: - util.natural_sort(instance_paths) - - label_paths = label_paths[: opt.max_dataset_size] - image_paths = image_paths[: opt.max_dataset_size] - instance_paths = instance_paths[: opt.max_dataset_size] - - if not opt.no_pairing_check: - for path1, path2 in zip(label_paths, image_paths): - assert self.paths_match(path1, path2), ( - "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." - % (path1, path2) - ) - - self.label_paths = label_paths - self.image_paths = image_paths - self.instance_paths = instance_paths - - size = len(self.label_paths) - self.dataset_size = size - - def get_paths(self, opt): - label_paths = [] - image_paths = [] - instance_paths = [] - assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" - return label_paths, image_paths, instance_paths - - def paths_match(self, path1, path2): - filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] - filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] - return filename1_without_ext == filename2_without_ext - - def __getitem__(self, index): - # Label Image - label_path = self.label_paths[index] - label = Image.open(label_path) - params = get_params(self.opt, label.size) - transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) - label_tensor = transform_label(label) * 255.0 - label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc - - # input image (real images) - image_path = self.image_paths[index] - assert self.paths_match( - label_path, image_path - ), "The label_path %s and image_path %s don't match." % (label_path, image_path) - image = Image.open(image_path) - image = image.convert("RGB") - - transform_image = get_transform(self.opt, params) - image_tensor = transform_image(image) - - # if using instance maps - if self.opt.no_instance: - instance_tensor = 0 - else: - instance_path = self.instance_paths[index] - instance = Image.open(instance_path) - if instance.mode == "L": - instance_tensor = transform_label(instance) * 255 - instance_tensor = instance_tensor.long() - else: - instance_tensor = transform_label(instance) - - input_dict = { - "label": label_tensor, - "instance": instance_tensor, - "image": image_tensor, - "path": image_path, - } - - # Give subclasses a chance to modify the final output - self.postprocess(input_dict) - - return input_dict - - def postprocess(self, input_dict): - return input_dict - - def __len__(self): - return self.dataset_size diff --git a/Face_Enhancement/models/__init__.py b/Face_Enhancement/models/__init__.py deleted file mode 100644 index d062296..0000000 --- a/Face_Enhancement/models/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
-# Licensed under the MIT License. - -import importlib -import torch - - -def find_model_using_name(model_name): - # Given the option --model [modelname], - # the file "models/modelname_model.py" - # will be imported. - assert model_name == 'pix2pix' - model_filename = f"{model_name}_model" - - from . import pix2pix_model as modellib - - # In the file, the class called ModelNameModel() will - # be instantiated. It has to be a subclass of torch.nn.Module, - # and it is case-insensitive. - model = None - target_model_name = model_name.replace("_", "") + "model" - for name, cls in modellib.__dict__.items(): - if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module): - model = cls - - if model is None: - raise SystemError( - "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." - % (model_filename, target_model_name) - ) - - return model - - -def get_option_setter(model_name): - model_class = find_model_using_name(model_name) - return model_class.modify_commandline_options - - -def create_model(opt): - model = find_model_using_name(opt.model) - instance = model(opt) - print("model [%s] was created" % (type(instance).__name__)) - - return instance diff --git a/Face_Enhancement/models/networks/__init__.py b/Face_Enhancement/models/networks/__init__.py deleted file mode 100644 index a75c814..0000000 --- a/Face_Enhancement/models/networks/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -from .base_network import BaseNetwork -from .generator import * -from .encoder import * -from ...util import util - - -def find_network_using_name(target_network_name, filename): - target_class_name = target_network_name + filename - module_name = "models.networks." + filename - network = util.find_class_in_module(target_class_name, module_name) - - assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network - - return network - - -def modify_commandline_options(parser, is_train): - opt, _ = parser.parse_known_args() - - netG_cls = find_network_using_name(opt.netG, "generator") - parser = netG_cls.modify_commandline_options(parser, is_train) - if is_train: - netD_cls = find_network_using_name(opt.netD, "discriminator") - parser = netD_cls.modify_commandline_options(parser, is_train) - netE_cls = find_network_using_name("conv", "encoder") - parser = netE_cls.modify_commandline_options(parser, is_train) - - return parser - - -def create_network(cls, opt): - net = cls(opt) - net.print_network() - if torch.cuda.is_available() and len(opt.gpu_ids) > 0: - net.cuda() - net.init_weights(opt.init_type, opt.init_variance) - return net - - -def define_G(opt): - netG_cls = find_network_using_name(opt.netG, "generator") - return create_network(netG_cls, opt) - - -def define_D(opt): - netD_cls = find_network_using_name(opt.netD, "discriminator") - return create_network(netD_cls, opt) - - -def define_E(opt): - # there exists only one encoder type - netE_cls = find_network_using_name("conv", "encoder") - return create_network(netE_cls, opt) diff --git a/Face_Enhancement/models/networks/architecture.py b/Face_Enhancement/models/networks/architecture.py deleted file mode 100644 index 2954ce8..0000000 --- a/Face_Enhancement/models/networks/architecture.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision -import torch.nn.utils.spectral_norm as spectral_norm -from .normalization import SPADE - - -# ResNet block that uses SPADE. -# It differs from the ResNet block of pix2pixHD in that -# it takes in the segmentation map as input, learns the skip connection if necessary, -# and applies normalization first and then convolution. -# This architecture seemed like a standard architecture for unconditional or -# class-conditional GAN architecture using residual block. -# The code was inspired from https://github.com/LMescheder/GAN_stability. -class SPADEResnetBlock(nn.Module): - def __init__(self, fin, fout, opt): - super().__init__() - # Attributes - self.learned_shortcut = fin != fout - fmiddle = min(fin, fout) - - self.opt = opt - # create conv layers - self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) - self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) - if self.learned_shortcut: - self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) - - # apply spectral norm if specified - if "spectral" in opt.norm_G: - self.conv_0 = spectral_norm(self.conv_0) - self.conv_1 = spectral_norm(self.conv_1) - if self.learned_shortcut: - self.conv_s = spectral_norm(self.conv_s) - - # define normalization layers - spade_config_str = opt.norm_G.replace("spectral", "") - self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt) - self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt) - if self.learned_shortcut: - self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt) - - # note the resnet block with SPADE also takes in |seg|, - # the semantic segmentation map as input - def forward(self, x, seg, degraded_image): - x_s = self.shortcut(x, seg, degraded_image) - - dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image))) - dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image))) - - out = x_s + dx - - return out - - def shortcut(self, x, seg, degraded_image): - if self.learned_shortcut: - x_s = self.conv_s(self.norm_s(x, seg, degraded_image)) - else: - x_s = x - return x_s - - def actvn(self, x): - return F.leaky_relu(x, 2e-1) - - -# ResNet block used in pix2pixHD -# We keep the same architecture as pix2pixHD. 
-class ResnetBlock(nn.Module): - def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3): - super().__init__() - - pw = (kernel_size - 1) // 2 - self.conv_block = nn.Sequential( - nn.ReflectionPad2d(pw), - norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), - activation, - nn.ReflectionPad2d(pw), - norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), - ) - - def forward(self, x): - y = self.conv_block(x) - out = x + y - return out - - -# VGG architecter, used for the perceptual loss using a pretrained VGG network -class VGG19(torch.nn.Module): - def __init__(self, requires_grad=False): - super().__init__() - vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - for x in range(2): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(2, 7): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(7, 12): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(12, 21): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in range(21, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - h_relu1 = self.slice1(X) - h_relu2 = self.slice2(h_relu1) - h_relu3 = self.slice3(h_relu2) - h_relu4 = self.slice4(h_relu3) - h_relu5 = self.slice5(h_relu4) - out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] - return out - - -class SPADEResnetBlock_non_spade(nn.Module): - def __init__(self, fin, fout, opt): - super().__init__() - # Attributes - self.learned_shortcut = fin != fout - fmiddle = min(fin, fout) - - self.opt = opt - # create conv layers - self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) - self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) - if self.learned_shortcut: - self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) - - # apply spectral norm if specified - if "spectral" in opt.norm_G: - self.conv_0 = spectral_norm(self.conv_0) - self.conv_1 = spectral_norm(self.conv_1) - if self.learned_shortcut: - self.conv_s = spectral_norm(self.conv_s) - - # define normalization layers - spade_config_str = opt.norm_G.replace("spectral", "") - self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt) - self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt) - if self.learned_shortcut: - self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt) - - # note the resnet block with SPADE also takes in |seg|, - # the semantic segmentation map as input - def forward(self, x, seg, degraded_image): - x_s = self.shortcut(x, seg, degraded_image) - - dx = self.conv_0(self.actvn(x)) - dx = self.conv_1(self.actvn(dx)) - - out = x_s + dx - - return out - - def shortcut(self, x, seg, degraded_image): - if self.learned_shortcut: - x_s = self.conv_s(x) - else: - x_s = x - return x_s - - def actvn(self, x): - return F.leaky_relu(x, 2e-1) diff --git a/Face_Enhancement/models/networks/base_network.py b/Face_Enhancement/models/networks/base_network.py deleted file mode 100644 index 13befe1..0000000 --- a/Face_Enhancement/models/networks/base_network.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import torch.nn as nn -from torch.nn import init - - -class BaseNetwork(nn.Module): - def __init__(self): - super(BaseNetwork, self).__init__() - - @staticmethod - def modify_commandline_options(parser, is_train): - return parser - - def print_network(self): - if isinstance(self, list): - self = self[0] - num_params = 0 - for param in self.parameters(): - num_params += param.numel() - print( - "Network [%s] was created. Total number of parameters: %.1f million." - % (type(self).__name__, num_params / 1000000) - ) - - def init_weights(self, init_type="normal", gain=0.02): - def init_func(m): - classname = m.__class__.__name__ - if classname.find("BatchNorm2d") != -1: - if hasattr(m, "weight") and m.weight is not None: - init.normal_(m.weight.data, 1.0, gain) - if hasattr(m, "bias") and m.bias is not None: - init.constant_(m.bias.data, 0.0) - elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): - if init_type == "normal": - init.normal_(m.weight.data, 0.0, gain) - elif init_type == "xavier": - init.xavier_normal_(m.weight.data, gain=gain) - elif init_type == "xavier_uniform": - init.xavier_uniform_(m.weight.data, gain=1.0) - elif init_type == "kaiming": - init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") - elif init_type == "orthogonal": - init.orthogonal_(m.weight.data, gain=gain) - elif init_type == "none": # uses pytorch's default init method - m.reset_parameters() - else: - raise NotImplementedError("initialization method [%s] is not implemented" % init_type) - if hasattr(m, "bias") and m.bias is not None: - init.constant_(m.bias.data, 0.0) - - self.apply(init_func) - - # propagate to children - for m in self.children(): - if hasattr(m, "init_weights"): - m.init_weights(init_type, gain) diff --git a/Face_Enhancement/models/networks/encoder.py b/Face_Enhancement/models/networks/encoder.py deleted file mode 100644 index 4ff05e8..0000000 --- a/Face_Enhancement/models/networks/encoder.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from .base_network import BaseNetwork -from .normalization import get_nonspade_norm_layer - - -class ConvEncoder(BaseNetwork): - """ Same architecture as the image discriminator """ - - def __init__(self, opt): - super().__init__() - - kw = 3 - pw = int(np.ceil((kw - 1.0) / 2)) - ndf = opt.ngf - norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) - self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) - self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) - self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) - self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) - self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) - if opt.crop_size >= 256: - self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) - - self.so = s0 = 4 - self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256) - self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256) - - self.actvn = nn.LeakyReLU(0.2, False) - self.opt = opt - - def forward(self, x): - if x.size(2) != 256 or x.size(3) != 256: - x = F.interpolate(x, size=(256, 256), mode="bilinear") - - x = self.layer1(x) - x = self.layer2(self.actvn(x)) - x = self.layer3(self.actvn(x)) - x = self.layer4(self.actvn(x)) - x = self.layer5(self.actvn(x)) - if self.opt.crop_size >= 256: - x = self.layer6(self.actvn(x)) - x = self.actvn(x) - - x = x.view(x.size(0), -1) - mu = self.fc_mu(x) - logvar = self.fc_var(x) - - return mu, logvar diff --git a/Face_Enhancement/models/networks/generator.py b/Face_Enhancement/models/networks/generator.py deleted file mode 100644 index c93fd56..0000000 --- a/Face_Enhancement/models/networks/generator.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from .base_network import BaseNetwork -from .normalization import get_nonspade_norm_layer -from .architecture import ResnetBlock as ResnetBlock -from .architecture import SPADEResnetBlock as SPADEResnetBlock -from .architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade - - -class SPADEGenerator(BaseNetwork): - @staticmethod - def modify_commandline_options(parser, is_train): - parser.set_defaults(norm_G="spectralspadesyncbatch3x3") - parser.add_argument( - "--num_upsampling_layers", - choices=("normal", "more", "most"), - default="normal", - help="If 'more', adds upsampling layer between the two middle resnet blocks. 
If 'most', also add one more upsampling + resnet layer at the end of the generator", - ) - - return parser - - def __init__(self, opt): - super().__init__() - self.opt = opt - nf = opt.ngf - - self.sw, self.sh = self.compute_latent_vector_size(opt) - - print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh)) - - if opt.use_vae: - # In case of VAE, we will sample from random z vector - self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh) - else: - # Otherwise, we make the network deterministic by starting with - # downsampled segmentation map instead of random z - if self.opt.no_parsing_map: - self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1) - else: - self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1) - - if self.opt.injection_layer == "all" or self.opt.injection_layer == "1": - self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt) - else: - self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt) - - if self.opt.injection_layer == "all" or self.opt.injection_layer == "2": - self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt) - self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt) - - else: - self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt) - self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt) - - if self.opt.injection_layer == "all" or self.opt.injection_layer == "3": - self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt) - else: - self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt) - - if self.opt.injection_layer == "all" or self.opt.injection_layer == "4": - self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt) - else: - self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt) - - if self.opt.injection_layer == "all" or self.opt.injection_layer == "5": - self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt) - else: - self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt) - - if self.opt.injection_layer == "all" or self.opt.injection_layer == "6": - self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt) - else: - self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt) - - final_nc = nf - - if opt.num_upsampling_layers == "most": - self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt) - final_nc = nf // 2 - - self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) - - self.up = nn.Upsample(scale_factor=2) - - def compute_latent_vector_size(self, opt): - if opt.num_upsampling_layers == "normal": - num_up_layers = 5 - elif opt.num_upsampling_layers == "more": - num_up_layers = 6 - elif opt.num_upsampling_layers == "most": - num_up_layers = 7 - else: - raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers) - - sw = opt.load_size // (2 ** num_up_layers) - sh = round(sw / opt.aspect_ratio) - - return sw, sh - - def forward(self, input, degraded_image, z=None): - seg = input - - if self.opt.use_vae: - # we sample z from unit normal and reshape the tensor - if z is None: - z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device()) - x = self.fc(z) - x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw) - else: - # we downsample segmap and run convolution - if self.opt.no_parsing_map: - x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear") - else: - x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest") - x = self.fc(x) - - x = self.head_0(x, seg, degraded_image) - - x = self.up(x) - x = self.G_middle_0(x, seg, degraded_image) - - if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers 
== "most": - x = self.up(x) - - x = self.G_middle_1(x, seg, degraded_image) - - x = self.up(x) - x = self.up_0(x, seg, degraded_image) - x = self.up(x) - x = self.up_1(x, seg, degraded_image) - x = self.up(x) - x = self.up_2(x, seg, degraded_image) - x = self.up(x) - x = self.up_3(x, seg, degraded_image) - - if self.opt.num_upsampling_layers == "most": - x = self.up(x) - x = self.up_4(x, seg, degraded_image) - - x = self.conv_img(F.leaky_relu(x, 2e-1)) - x = F.tanh(x) - - return x - - -class Pix2PixHDGenerator(BaseNetwork): - @staticmethod - def modify_commandline_options(parser, is_train): - parser.add_argument( - "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG" - ) - parser.add_argument( - "--resnet_n_blocks", - type=int, - default=9, - help="number of residual blocks in the global generator network", - ) - parser.add_argument( - "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block" - ) - parser.add_argument( - "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution" - ) - # parser.set_defaults(norm_G='instance') - return parser - - def __init__(self, opt): - super().__init__() - input_nc = 3 - - # print("xxxxx") - # print(opt.norm_G) - norm_layer = get_nonspade_norm_layer(opt, opt.norm_G) - activation = nn.ReLU(False) - - model = [] - - # initial conv - model += [ - nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2), - norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)), - activation, - ] - - # downsample - mult = 1 - for i in range(opt.resnet_n_downsample): - model += [ - norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)), - activation, - ] - mult *= 2 - - # resnet blocks - for i in range(opt.resnet_n_blocks): - model += [ - ResnetBlock( - opt.ngf * mult, - norm_layer=norm_layer, - activation=activation, - kernel_size=opt.resnet_kernel_size, - ) - ] - - # upsample - for i in range(opt.resnet_n_downsample): - nc_in = int(opt.ngf * mult) - nc_out = int((opt.ngf * mult) / 2) - model += [ - norm_layer( - nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1) - ), - activation, - ] - mult = mult // 2 - - # final output conv - model += [ - nn.ReflectionPad2d(3), - nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0), - nn.Tanh(), - ] - - self.model = nn.Sequential(*model) - - def forward(self, input, degraded_image, z=None): - return self.model(degraded_image) - diff --git a/Face_Enhancement/models/networks/normalization.py b/Face_Enhancement/models/networks/normalization.py deleted file mode 100644 index 63b3b93..0000000 --- a/Face_Enhancement/models/networks/normalization.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import re -import torch -import torch.nn as nn -import torch.nn.functional as F -from .sync_batchnorm import SynchronizedBatchNorm2d -import torch.nn.utils.spectral_norm as spectral_norm - - -def get_nonspade_norm_layer(opt, norm_type="instance"): - # helper function to get # output channels of the previous layer - def get_out_channel(layer): - if hasattr(layer, "out_channels"): - return getattr(layer, "out_channels") - return layer.weight.size(0) - - # this function will be returned - def add_norm_layer(layer): - nonlocal norm_type - if norm_type.startswith("spectral"): - layer = spectral_norm(layer) - subnorm_type = norm_type[len("spectral") :] - - if subnorm_type == "none" or len(subnorm_type) == 0: - return layer - - # remove bias in the previous layer, which is meaningless - # since it has no effect after normalization - if getattr(layer, "bias", None) is not None: - delattr(layer, "bias") - layer.register_parameter("bias", None) - - if subnorm_type == "batch": - norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) - elif subnorm_type == "sync_batch": - norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) - elif subnorm_type == "instance": - norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) - else: - raise ValueError("normalization layer %s is not recognized" % subnorm_type) - - return nn.Sequential(layer, norm_layer) - - return add_norm_layer - - -class SPADE(nn.Module): - def __init__(self, config_text, norm_nc, label_nc, opt): - super().__init__() - - assert config_text.startswith("spade") - parsed = re.search("spade(\D+)(\d)x\d", config_text) - param_free_norm_type = str(parsed.group(1)) - ks = int(parsed.group(2)) - self.opt = opt - if param_free_norm_type == "instance": - self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) - elif param_free_norm_type == "syncbatch": - self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) - elif param_free_norm_type == "batch": - self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) - else: - raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type) - - # The dimension of the intermediate embedding space. Yes, hardcoded. - nhidden = 128 - - pw = ks // 2 - - if self.opt.no_parsing_map: - self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()) - else: - self.mlp_shared = nn.Sequential( - nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU() - ) - self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) - self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) - - def forward(self, x, segmap, degraded_image): - - # Part 1. generate parameter-free normalized activations - normalized = self.param_free_norm(x) - - # Part 2. 
produce scaling and bias conditioned on semantic map - segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") - degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear") - - if self.opt.no_parsing_map: - actv = self.mlp_shared(degraded_face) - else: - actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1)) - gamma = self.mlp_gamma(actv) - beta = self.mlp_beta(actv) - - # apply scale and bias - out = normalized * (1 + gamma) + beta - - return out diff --git a/Face_Enhancement/models/networks/sync_batchnorm/__init__.py b/Face_Enhancement/models/networks/sync_batchnorm/__init__.py deleted file mode 100644 index 6d9b36c..0000000 --- a/Face_Enhancement/models/networks/sync_batchnorm/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -# File : __init__.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -from .batchnorm import set_sbn_eps_mode -from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d -from .batchnorm import patch_sync_batchnorm, convert_model -from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py b/Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py deleted file mode 100644 index bf8d7a7..0000000 --- a/Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py +++ /dev/null @@ -1,412 +0,0 @@ -# -*- coding: utf-8 -*- -# File : batchnorm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import collections -import contextlib - -import torch -import torch.nn.functional as F - -from torch.nn.modules.batchnorm import _BatchNorm - -try: - from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast -except ImportError: - ReduceAddCoalesced = Broadcast = None - -try: - from jactorch.parallel.comm import SyncMaster - from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback -except ImportError: - from .comm import SyncMaster - from .replicate import DataParallelWithCallback - -__all__ = [ - 'set_sbn_eps_mode', - 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', - 'patch_sync_batchnorm', 'convert_model' -] - - -SBN_EPS_MODE = 'clamp' - - -def set_sbn_eps_mode(mode): - global SBN_EPS_MODE - assert mode in ('clamp', 'plus') - SBN_EPS_MODE = mode - - -def _sum_ft(tensor): - """sum over the first and last dimention""" - return tensor.sum(dim=0).sum(dim=-1) - - -def _unsqueeze_ft(tensor): - """add new dimensions at the front and the tail""" - return tensor.unsqueeze(0).unsqueeze(-1) - - -_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) -_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) - - -class _SynchronizedBatchNorm(_BatchNorm): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): - assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 
- - super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, - track_running_stats=track_running_stats) - - if not self.track_running_stats: - import warnings - warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') - - self._sync_master = SyncMaster(self._data_parallel_master) - - self._is_parallel = False - self._parallel_id = None - self._slave_pipe = None - - def forward(self, input): - # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. - if not (self._is_parallel and self.training): - return F.batch_norm( - input, self.running_mean, self.running_var, self.weight, self.bias, - self.training, self.momentum, self.eps) - - # Resize the input to (B, C, -1). - input_shape = input.size() - assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) - input = input.view(input.size(0), self.num_features, -1) - - # Compute the sum and square-sum. - sum_size = input.size(0) * input.size(2) - input_sum = _sum_ft(input) - input_ssum = _sum_ft(input ** 2) - - # Reduce-and-broadcast the statistics. - if self._parallel_id == 0: - mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) - else: - mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) - - # Compute the output. - if self.affine: - # MJY:: Fuse the multiplication for speed. - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) - else: - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) - - # Reshape it. - return output.view(input_shape) - - def __data_parallel_replicate__(self, ctx, copy_id): - self._is_parallel = True - self._parallel_id = copy_id - - # parallel_id == 0 means master device. - if self._parallel_id == 0: - ctx.sync_master = self._sync_master - else: - self._slave_pipe = ctx.sync_master.register_slave(copy_id) - - def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" - - # Always using same "device order" makes the ReduceAdd operation faster. - # Thanks to:: Tete Xiao (http://tetexiao.com/) - intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) - - to_reduce = [i[1][:2] for i in intermediates] - to_reduce = [j for i in to_reduce for j in i] # flatten - target_gpus = [i[1].sum.get_device() for i in intermediates] - - sum_size = sum([i[1].sum_size for i in intermediates]) - sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) - mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) - - broadcasted = Broadcast.apply(target_gpus, mean, inv_std) - - outputs = [] - for i, rec in enumerate(intermediates): - outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) - - return outputs - - def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with sum and square-sum. This method - also maintains the moving average on the master device.""" - assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
- mean = sum_ / size - sumvar = ssum - sum_ * mean - unbias_var = sumvar / (size - 1) - bias_var = sumvar / size - - if hasattr(torch, 'no_grad'): - with torch.no_grad(): - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - else: - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - - if SBN_EPS_MODE == 'clamp': - return mean, bias_var.clamp(self.eps) ** -0.5 - elif SBN_EPS_MODE == 'plus': - return mean, (bias_var + self.eps) ** -0.5 - else: - raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) - - -class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): - r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a - mini-batch. - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm1d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm - - Args: - num_features: num_features from an expected input of size - `batch_size x num_features [x width]` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape:: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' - .format(input.dim())) - - -class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch - of 3d inputs - - .. 
math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm2d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape:: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)' - .format(input.dim())) - - -class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch - of 4d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm3d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. 
- - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm - or Spatio-temporal BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x depth x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape:: - - Input: :math:`(N, C, D, H, W)` - - Output: :math:`(N, C, D, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)' - .format(input.dim())) - - -@contextlib.contextmanager -def patch_sync_batchnorm(): - import torch.nn as nn - - backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d - - nn.BatchNorm1d = SynchronizedBatchNorm1d - nn.BatchNorm2d = SynchronizedBatchNorm2d - nn.BatchNorm3d = SynchronizedBatchNorm3d - - yield - - nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup - - -def convert_model(module): - """Traverse the input module and its child recursively - and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d - to SynchronizedBatchNorm*N*d - - Args: - module: the input module needs to be convert to SyncBN model - - Examples: - >>> import torch.nn as nn - >>> import torchvision - >>> # m is a standard pytorch model - >>> m = torchvision.models.resnet18(True) - >>> m = nn.DataParallel(m) - >>> # after convert, m is using SyncBN - >>> m = convert_model(m) - """ - if isinstance(module, torch.nn.DataParallel): - mod = module.module - mod = convert_model(mod) - mod = DataParallelWithCallback(mod, device_ids=module.device_ids) - return mod - - mod = module - for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, - torch.nn.modules.batchnorm.BatchNorm2d, - torch.nn.modules.batchnorm.BatchNorm3d], - [SynchronizedBatchNorm1d, - SynchronizedBatchNorm2d, - SynchronizedBatchNorm3d]): - if isinstance(module, pth_module): - mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) - mod.running_mean = module.running_mean - mod.running_var = module.running_var - if module.affine: - mod.weight.data = module.weight.data.clone().detach() - mod.bias.data = module.bias.data.clone().detach() - - for name, child in module.named_children(): - mod.add_module(name, convert_model(child)) - - return mod diff --git a/Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py b/Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py deleted file mode 100644 index 18145c3..0000000 --- a/Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py +++ /dev/null @@ -1,74 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# File : batchnorm_reimpl.py -# Author : acgtyrant -# Date : 11/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import torch -import torch.nn as nn -import torch.nn.init as init - -__all__ = ['BatchNorm2dReimpl'] - - -class BatchNorm2dReimpl(nn.Module): - """ - A re-implementation of batch normalization, used for testing the numerical - stability. - - Author: acgtyrant - See also: - https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 - """ - def __init__(self, num_features, eps=1e-5, momentum=0.1): - super().__init__() - - self.num_features = num_features - self.eps = eps - self.momentum = momentum - self.weight = nn.Parameter(torch.empty(num_features)) - self.bias = nn.Parameter(torch.empty(num_features)) - self.register_buffer('running_mean', torch.zeros(num_features)) - self.register_buffer('running_var', torch.ones(num_features)) - self.reset_parameters() - - def reset_running_stats(self): - self.running_mean.zero_() - self.running_var.fill_(1) - - def reset_parameters(self): - self.reset_running_stats() - init.uniform_(self.weight) - init.zeros_(self.bias) - - def forward(self, input_): - batchsize, channels, height, width = input_.size() - numel = batchsize * height * width - input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) - sum_ = input_.sum(1) - sum_of_square = input_.pow(2).sum(1) - mean = sum_ / numel - sumvar = sum_of_square - sum_ * mean - - self.running_mean = ( - (1 - self.momentum) * self.running_mean - + self.momentum * mean.detach() - ) - unbias_var = sumvar / (numel - 1) - self.running_var = ( - (1 - self.momentum) * self.running_var - + self.momentum * unbias_var.detach() - ) - - bias_var = sumvar / numel - inv_std = 1 / (bias_var + self.eps).pow(0.5) - output = ( - (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * - self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) - - return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() - diff --git a/Face_Enhancement/models/networks/sync_batchnorm/comm.py b/Face_Enhancement/models/networks/sync_batchnorm/comm.py deleted file mode 100644 index 922f8c4..0000000 --- a/Face_Enhancement/models/networks/sync_batchnorm/comm.py +++ /dev/null @@ -1,137 +0,0 @@ -# -*- coding: utf-8 -*- -# File : comm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import queue -import collections -import threading - -__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] - - -class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" - - def __init__(self): - self._result = None - self._lock = threading.Lock() - self._cond = threading.Condition(self._lock) - - def put(self, result): - with self._lock: - assert self._result is None, 'Previous result has\'t been fetched.' - self._result = result - self._cond.notify() - - def get(self): - with self._lock: - if self._result is None: - self._cond.wait() - - res = self._result - self._result = None - return res - - -_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) -_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) - - -class SlavePipe(_SlavePipeBase): - """Pipe for master-slave communication.""" - - def run_slave(self, msg): - self.queue.put((self.identifier, msg)) - ret = self.result.get() - self.queue.put(True) - return ret - - -class SyncMaster(object): - """An abstract `SyncMaster` object. 
- - - During the replication, as the data parallel will trigger an callback of each module, all slave devices should - call `register(id)` and obtain an `SlavePipe` to communicate with the master. - - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, - and passed to a registered callback. - - After receiving the messages, the master device should gather the information and determine to message passed - back to each slave devices. - """ - - def __init__(self, master_callback): - """ - - Args: - master_callback: a callback to be invoked after having collected messages from slave devices. - """ - self._master_callback = master_callback - self._queue = queue.Queue() - self._registry = collections.OrderedDict() - self._activated = False - - def __getstate__(self): - return {'master_callback': self._master_callback} - - def __setstate__(self, state): - self.__init__(state['master_callback']) - - def register_slave(self, identifier): - """ - Register an slave device. - - Args: - identifier: an identifier, usually is the device id. - - Returns: a `SlavePipe` object which can be used to communicate with the master device. - - """ - if self._activated: - assert self._queue.empty(), 'Queue is not clean before next initialization.' - self._activated = False - self._registry.clear() - future = FutureResult() - self._registry[identifier] = _MasterRegistry(future) - return SlavePipe(identifier, self._queue, future) - - def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. - The messages were first collected from each devices (including the master device), and then - an callback will be invoked to compute the message to be sent back to each devices - (including the master device). - - Args: - master_msg: the message that the master want to send to itself. This will be placed as the first - message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. - - Returns: the message to be sent back to the master device. - - """ - self._activated = True - - intermediates = [(0, master_msg)] - for i in range(self.nr_slaves): - intermediates.append(self._queue.get()) - - results = self._master_callback(intermediates) - assert results[0][0] == 0, 'The first result should belongs to the master.' - - for i, res in results: - if i == 0: - continue - self._registry[i].result.put(res) - - for i in range(self.nr_slaves): - assert self._queue.get() is True - - return results[0][1] - - @property - def nr_slaves(self): - return len(self._registry) diff --git a/Face_Enhancement/models/networks/sync_batchnorm/replicate.py b/Face_Enhancement/models/networks/sync_batchnorm/replicate.py deleted file mode 100644 index b71c7b8..0000000 --- a/Face_Enhancement/models/networks/sync_batchnorm/replicate.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# File : replicate.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import functools - -from torch.nn.parallel.data_parallel import DataParallel - -__all__ = [ - 'CallbackContext', - 'execute_replication_callbacks', - 'DataParallelWithCallback', - 'patch_replication_callback' -] - - -class CallbackContext(object): - pass - - -def execute_replication_callbacks(modules): - """ - Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. - - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Note that, as all modules are isomorphism, we assign each sub-module with a context - (shared among multiple copies of this module on different devices). - Through this context, different copies can share some information. - - We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback - of any slave copies. - """ - master_copy = modules[0] - nr_modules = len(list(master_copy.modules())) - ctxs = [CallbackContext() for _ in range(nr_modules)] - - for i, module in enumerate(modules): - for j, m in enumerate(module.modules()): - if hasattr(m, '__data_parallel_replicate__'): - m.__data_parallel_replicate__(ctxs[j], i) - - -class DataParallelWithCallback(DataParallel): - """ - Data Parallel with a replication callback. - - An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by - original `replicate` function. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - # sync_bn.__data_parallel_replicate__ will be invoked. - """ - - def replicate(self, module, device_ids): - modules = super(DataParallelWithCallback, self).replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - -def patch_replication_callback(data_parallel): - """ - Monkey-patch an existing `DataParallel` object. Add the replication callback. - Useful when you have customized `DataParallel` implementation. - - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) - > patch_replication_callback(sync_bn) - # this is equivalent to - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - """ - - assert isinstance(data_parallel, DataParallel) - - old_replicate = data_parallel.replicate - - @functools.wraps(old_replicate) - def new_replicate(module, device_ids): - modules = old_replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - data_parallel.replicate = new_replicate diff --git a/Face_Enhancement/models/networks/sync_batchnorm/unittest.py b/Face_Enhancement/models/networks/sync_batchnorm/unittest.py deleted file mode 100644 index 998223a..0000000 --- a/Face_Enhancement/models/networks/sync_batchnorm/unittest.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# File : unittest.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import unittest -import torch - - -class TorchTestCase(unittest.TestCase): - def assertTensorClose(self, x, y): - adiff = float((x - y).abs().max()) - if (y == 0).all(): - rdiff = 'NaN' - else: - rdiff = float((adiff / y).abs().max()) - - message = ( - 'Tensor close check failed\n' - 'adiff={}\n' - 'rdiff={}\n' - ).format(adiff, rdiff) - self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) - diff --git a/Face_Enhancement/models/pix2pix_model.py b/Face_Enhancement/models/pix2pix_model.py deleted file mode 100644 index 94cc8cf..0000000 --- a/Face_Enhancement/models/pix2pix_model.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -from . import networks -from ..util import util - - -class Pix2PixModel(torch.nn.Module): - @staticmethod - def modify_commandline_options(parser, is_train): - networks.modify_commandline_options(parser, is_train) - return parser - - def __init__(self, opt): - super().__init__() - self.opt = opt - self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor - self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor - - self.netG, self.netD, self.netE = self.initialize_networks(opt) - - # set loss functions - if opt.isTrain: - self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) - self.criterionFeat = torch.nn.L1Loss() - if not opt.no_vgg_loss: - self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids) - if opt.use_vae: - self.KLDLoss = networks.KLDLoss() - - # Entry point for all calls involving forward pass - # of deep networks. We used this approach since DataParallel module - # can't parallelize custom functions, we branch to different - # routines based on |mode|. 
- def forward(self, data, mode): - input_semantics, real_image, degraded_image = self.preprocess_input(data) - - if mode == "generator": - g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image) - return g_loss, generated - elif mode == "discriminator": - d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image) - return d_loss - elif mode == "encode_only": - z, mu, logvar = self.encode_z(real_image) - return mu, logvar - elif mode == "inference": - with torch.no_grad(): - fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image) - return fake_image - else: - raise ValueError("|mode| is invalid") - - def create_optimizers(self, opt): - G_params = list(self.netG.parameters()) - if opt.use_vae: - G_params += list(self.netE.parameters()) - if opt.isTrain: - D_params = list(self.netD.parameters()) - - beta1, beta2 = opt.beta1, opt.beta2 - if opt.no_TTUR: - G_lr, D_lr = opt.lr, opt.lr - else: - G_lr, D_lr = opt.lr / 2, opt.lr * 2 - - optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2)) - optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2)) - - return optimizer_G, optimizer_D - - def save(self, epoch): - util.save_network(self.netG, "G", epoch, self.opt) - util.save_network(self.netD, "D", epoch, self.opt) - if self.opt.use_vae: - util.save_network(self.netE, "E", epoch, self.opt) - - ############################################################################ - # Private helper methods - ############################################################################ - - def initialize_networks(self, opt): - netG = networks.define_G(opt) - netD = networks.define_D(opt) if opt.isTrain else None - netE = networks.define_E(opt) if opt.use_vae else None - - if not opt.isTrain or opt.continue_train: - netG = util.load_network(netG, "G", opt.which_epoch, opt) - if opt.isTrain: - netD = util.load_network(netD, "D", opt.which_epoch, opt) - if opt.use_vae: - netE = util.load_network(netE, "E", opt.which_epoch, opt) - - return netG, netD, netE - - # preprocess the input, such as moving the tensors to GPUs and - # transforming the label map to one-hot encoding - # |data|: dictionary of the input data - - def preprocess_input(self, data): - # move to GPU and change data types - # data['label'] = data['label'].long() - - if not self.opt.isTrain: - if self.use_gpu(): - data["label"] = data["label"].cuda() - data["image"] = data["image"].cuda() - return data["label"], data["image"], data["image"] - - ## While testing, the input image is the degraded face - if self.use_gpu(): - data["label"] = data["label"].cuda() - data["degraded_image"] = data["degraded_image"].cuda() - data["image"] = data["image"].cuda() - - # # create one-hot label map - # label_map = data['label'] - # bs, _, h, w = label_map.size() - # nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \ - # else self.opt.label_nc - # input_label = self.FloatTensor(bs, nc, h, w).zero_() - # input_semantics = input_label.scatter_(1, label_map, 1.0) - - return data["label"], data["image"], data["degraded_image"] - - def compute_generator_loss(self, input_semantics, degraded_image, real_image): - G_losses = {} - - fake_image, KLD_loss = self.generate_fake( - input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae - ) - - if self.opt.use_vae: - G_losses["KLD"] = KLD_loss - - pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image) - - G_losses["GAN"] = self.criterionGAN(pred_fake, 
True, for_discriminator=False) - - if not self.opt.no_ganFeat_loss: - num_D = len(pred_fake) - GAN_Feat_loss = self.FloatTensor(1).fill_(0) - for i in range(num_D): # for each discriminator - # last output is the final prediction, so we exclude it - num_intermediate_outputs = len(pred_fake[i]) - 1 - for j in range(num_intermediate_outputs): # for each layer output - unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) - GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D - G_losses["GAN_Feat"] = GAN_Feat_loss - - if not self.opt.no_vgg_loss: - G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg - - return G_losses, fake_image - - def compute_discriminator_loss(self, input_semantics, degraded_image, real_image): - D_losses = {} - with torch.no_grad(): - fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image) - fake_image = fake_image.detach() - fake_image.requires_grad_() - - pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image) - - D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True) - D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True) - - return D_losses - - def encode_z(self, real_image): - mu, logvar = self.netE(real_image) - z = self.reparameterize(mu, logvar) - return z, mu, logvar - - def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False): - z = None - KLD_loss = None - if self.opt.use_vae: - z, mu, logvar = self.encode_z(real_image) - if compute_kld_loss: - KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld - - fake_image = self.netG(input_semantics, degraded_image, z=z) - - assert ( - not compute_kld_loss - ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False" - - return fake_image, KLD_loss - - # Given fake and real image, return the prediction of discriminator - # for each fake and real image. - - def discriminate(self, input_semantics, fake_image, real_image): - - if self.opt.no_parsing_map: - fake_concat = fake_image - real_concat = real_image - else: - fake_concat = torch.cat([input_semantics, fake_image], dim=1) - real_concat = torch.cat([input_semantics, real_image], dim=1) - - # In Batch Normalization, the fake and real images are - # recommended to be in the same batch to avoid disparate - # statistics in fake and real images. - # So both fake and real images are fed to D all at once. 
- fake_and_real = torch.cat([fake_concat, real_concat], dim=0) - - discriminator_out = self.netD(fake_and_real) - - pred_fake, pred_real = self.divide_pred(discriminator_out) - - return pred_fake, pred_real - - # Take the prediction of fake and real images from the combined batch - def divide_pred(self, pred): - # the prediction contains the intermediate outputs of multiscale GAN, - # so it's usually a list - if type(pred) == list: - fake = [] - real = [] - for p in pred: - fake.append([tensor[: tensor.size(0) // 2] for tensor in p]) - real.append([tensor[tensor.size(0) // 2 :] for tensor in p]) - else: - fake = pred[: pred.size(0) // 2] - real = pred[pred.size(0) // 2 :] - - return fake, real - - def get_edges(self, t): - edge = self.ByteTensor(t.size()).zero_() - edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) - edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) - edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) - edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) - return edge.float() - - def reparameterize(self, mu, logvar): - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return eps.mul(std) + mu - - def use_gpu(self): - return torch.cuda.is_available() and len(self.opt.gpu_ids) > 0 diff --git a/Face_Enhancement/options/__init__.py b/Face_Enhancement/options/__init__.py deleted file mode 100644 index 59e481e..0000000 --- a/Face_Enhancement/options/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. diff --git a/Face_Enhancement/options/base_options.py b/Face_Enhancement/options/base_options.py deleted file mode 100644 index ac02059..0000000 --- a/Face_Enhancement/options/base_options.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import argparse -import pickle -import torch -import sys -import os - -from ..util import util -from .. import models - - -class BaseOptions: - def __init__(self): - self.initialized = False - - def initialize(self, parser): - # experiment specifics - parser.add_argument( - "--name", - type=str, - default="label2coco", - help="name of the experiment. It decides where to store samples and models", - ) - - parser.add_argument( - "--gpu_ids", - type=str, - default="0", - help="gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU", - ) - parser.add_argument( - "--checkpoints_dir", - type=str, - default="./checkpoints", - help="models are saved here", - ) - parser.add_argument( - "--model", type=str, default="pix2pix", help="which model to use" - ) - parser.add_argument( - "--norm_G", - type=str, - default="spectralinstance", - help="instance normalization or batch normalization", - ) - parser.add_argument( - "--norm_D", - type=str, - default="spectralinstance", - help="instance normalization or batch normalization", - ) - parser.add_argument( - "--norm_E", - type=str, - default="spectralinstance", - help="instance normalization or batch normalization", - ) - parser.add_argument( - "--phase", type=str, default="train", help="train, val, test, etc" - ) - - # input/output sizes - parser.add_argument("--batchSize", type=int, default=1, help="input batch size") - parser.add_argument( - "--preprocess_mode", - type=str, - default="scale_width_and_crop", - help="scaling and cropping of images at load time.", - choices=( - "resize_and_crop", - "crop", - "scale_width", - "scale_width_and_crop", - "scale_shortside", - "scale_shortside_and_crop", - "fixed", - "none", - "resize", - ), - ) - parser.add_argument( - "--load_size", - type=int, - default=1024, - help="Scale images to this size. The final image will be cropped to --crop_size.", - ) - parser.add_argument( - "--crop_size", - type=int, - default=512, - help="Crop to the width of crop_size (after initially scaling the images to load_size.)", - ) - parser.add_argument( - "--aspect_ratio", - type=float, - default=1.0, - help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio", - ) - parser.add_argument( - "--label_nc", - type=int, - default=182, - help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.", - ) - parser.add_argument( - "--contain_dontcare_label", - action="store_true", - help="if the label map contains dontcare label (dontcare=255)", - ) - parser.add_argument( - "--output_nc", type=int, default=3, help="# of output image channels" - ) - - # for setting inputs - parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/") - parser.add_argument("--dataset_mode", type=str, default="coco") - parser.add_argument( - "--serial_batches", - action="store_true", - help="if true, takes images in order to make batches, otherwise takes them randomly", - ) - parser.add_argument( - "--no_flip", - action="store_true", - help="if specified, do not flip the images for data argumentation", - ) - parser.add_argument( - "--nThreads", default=0, type=int, help="# threads for loading data" - ) - parser.add_argument( - "--max_dataset_size", - type=int, - default=sys.maxsize, - help="Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.", - ) - parser.add_argument( - "--load_from_opt_file", - action="store_true", - help="load the options from checkpoints and use that as default", - ) - parser.add_argument( - "--cache_filelist_write", - action="store_true", - help="saves the current filelist into a text file, so that it loads faster", - ) - parser.add_argument( - "--cache_filelist_read", - action="store_true", - help="reads from the file list cache", - ) - - # for displays - parser.add_argument( - "--display_winsize", type=int, default=400, help="display window size" - ) - - # for generator - parser.add_argument( - "--netG", - type=str, - default="spade", - help="selects model to use for netG (pix2pixhd | spade)", - ) - parser.add_argument( - "--ngf", type=int, default=64, help="# of gen filters in first conv layer" - ) - parser.add_argument( - "--init_type", - type=str, - default="xavier", - help="network initialization [normal|xavier|kaiming|orthogonal]", - ) - parser.add_argument( - "--init_variance", - type=float, - default=0.02, - help="variance of the initialization distribution", - ) - parser.add_argument( - "--z_dim", type=int, default=256, help="dimension of the latent z vector" - ) - parser.add_argument( - "--no_parsing_map", - action="store_true", - help="During training, we do not use the parsing map", - ) - - # for instance-wise features - parser.add_argument( - "--no_instance", - action="store_true", - help="if specified, do *not* add instance map as input", - ) - parser.add_argument( - "--nef", - type=int, - default=16, - help="# of encoder filters in the first conv layer", - ) - parser.add_argument( - "--use_vae", - action="store_true", - help="enable training with an image encoder.", - ) - parser.add_argument( - "--tensorboard_log", - action="store_true", - help="use tensorboard to record the resutls", - ) - - # parser.add_argument('--img_dir',) - parser.add_argument( - "--old_face_folder", - type=str, - default="", - help="The folder name of input old face", - ) - parser.add_argument( - "--old_face_label_folder", - type=str, - default="", - help="The folder name of input old face label", - ) - - parser.add_argument("--injection_layer", type=str, default="all", help="") - - self.initialized = True - return parser - - def gather_options(self): - # initialize parser with basic options - if not self.initialized: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser = self.initialize(parser) - - # get the basic options - opt, unknown = parser.parse_known_args() - - # modify model-related parser options - model_name = opt.model - model_option_setter = models.get_option_setter(model_name) - parser = model_option_setter(parser, self.isTrain) - - # modify dataset-related parser options - # dataset_mode = opt.dataset_mode - # dataset_option_setter = data.get_option_setter(dataset_mode) - # parser = dataset_option_setter(parser, self.isTrain) - - opt, unknown = parser.parse_known_args() - - # if there is opt_file, load it. 
- # The previous default options will be overwritten - if opt.load_from_opt_file: - parser = self.update_options_from_file(parser, opt) - - opt, unknown = parser.parse_known_args() - self.parser = parser - return opt - - def print_options(self, opt): - message = "" - message += "----------------- Options ---------------\n" - for k, v in sorted(vars(opt).items()): - comment = "" - default = self.parser.get_default(k) - if v != default: - comment = "\t[default: %s]" % str(default) - message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment) - message += "----------------- End -------------------" - # print(message) - - def option_file_path(self, opt, makedir=False): - expr_dir = os.path.join(opt.checkpoints_dir, opt.name) - if makedir: - util.mkdirs(expr_dir) - file_name = os.path.join(expr_dir, "opt") - return file_name - - def save_options(self, opt): - file_name = self.option_file_path(opt, makedir=True) - with open(file_name + ".txt", "wt") as opt_file: - for k, v in sorted(vars(opt).items()): - comment = "" - default = self.parser.get_default(k) - if v != default: - comment = "\t[default: %s]" % str(default) - opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment)) - - with open(file_name + ".pkl", "wb") as opt_file: - pickle.dump(opt, opt_file) - - def update_options_from_file(self, parser, opt): - new_opt = self.load_options(opt) - for k, v in sorted(vars(opt).items()): - if hasattr(new_opt, k) and v != getattr(new_opt, k): - new_val = getattr(new_opt, k) - parser.set_defaults(**{k: new_val}) - return parser - - def load_options(self, opt): - file_name = self.option_file_path(opt, makedir=False) - new_opt = pickle.load(open(file_name + ".pkl", "rb")) - return new_opt - - def parse(self): - - opt = self.gather_options() - opt.isTrain = self.isTrain # train or test - opt.contain_dontcare_label = False - - self.print_options(opt) - if opt.isTrain: - self.save_options(opt) - - # Set semantic_nc based on the option. - # This will be convenient in many places - opt.semantic_nc = ( - opt.label_nc - + (1 if opt.contain_dontcare_label else 0) - + (0 if opt.no_instance else 1) - ) - - str_ids = opt.gpu_ids.split(",") - opt.gpu_ids = [] - for str_id in str_ids: - int_id = int(str_id) - if int_id >= 0: - opt.gpu_ids.append(int_id) - - # set gpu ids - if torch.cuda.is_available() and len(opt.gpu_ids) > 0: - if torch.cuda.device_count() > opt.gpu_ids[0]: - try: - torch.cuda.set_device(opt.gpu_ids[0]) - except: - print("Failed to set GPU device. Using CPU...") - - else: - print("Invalid GPU ID. Using CPU...") - - # assert (len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids)) - - self.opt = opt - return self.opt diff --git a/Face_Enhancement/options/test_options.py b/Face_Enhancement/options/test_options.py deleted file mode 100644 index 4fc38bd..0000000 --- a/Face_Enhancement/options/test_options.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .base_options import BaseOptions - - -class TestOptions(BaseOptions): - def initialize(self, parser): - BaseOptions.initialize(self, parser) - parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.") - parser.add_argument( - "--which_epoch", - type=str, - default="latest", - help="which epoch to load? 
set to latest to use latest cached model", - ) - parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run") - - parser.set_defaults( - preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256 - ) - parser.set_defaults(serial_batches=True) - parser.set_defaults(no_flip=True) - parser.set_defaults(phase="test") - self.isTrain = False - return parser diff --git a/Face_Enhancement/test_face.py b/Face_Enhancement/test_face.py deleted file mode 100644 index 70a154a..0000000 --- a/Face_Enhancement/test_face.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import torchvision.transforms as T -import warnings - -from .options.test_options import TestOptions -from .models.pix2pix_model import Pix2PixModel -from .data import create_dataloader - -warnings.filterwarnings("ignore", category=UserWarning) -tensor2image = T.ToPILImage() - - -def test_face(face_images: list, args: dict) -> list: - opt = TestOptions().parse() - for K, V in args.items(): - setattr(opt, K, V) - - dataloader = create_dataloader(opt, face_images) - images = [] - - model = Pix2PixModel(opt) - model.eval() - - for i, data_i in enumerate(dataloader): - if i * opt.batchSize >= opt.how_many: - break - - generated = model(data_i, mode="inference") - - for b in range(generated.shape[0]): - images.append(tensor2image((generated[b] + 1) / 2)) - - return images diff --git a/Face_Enhancement/util/__init__.py b/Face_Enhancement/util/__init__.py deleted file mode 100644 index 59e481e..0000000 --- a/Face_Enhancement/util/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. diff --git a/Face_Enhancement/util/iter_counter.py b/Face_Enhancement/util/iter_counter.py deleted file mode 100644 index 277bb67..0000000 --- a/Face_Enhancement/util/iter_counter.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import os -import time -import numpy as np - - -# Helper class that keeps track of training iterations -class IterationCounter: - def __init__(self, opt, dataset_size): - self.opt = opt - self.dataset_size = dataset_size - - self.first_epoch = 1 - self.total_epochs = opt.niter + opt.niter_decay - self.epoch_iter = 0 # iter number within each epoch - self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt") - if opt.isTrain and opt.continue_train: - try: - self.first_epoch, self.epoch_iter = np.loadtxt( - self.iter_record_path, delimiter=",", dtype=int - ) - print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter)) - except: - print( - "Could not load iteration record at %s. Starting from beginning." 
% self.iter_record_path - ) - - self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter - - # return the iterator of epochs for the training - def training_epochs(self): - return range(self.first_epoch, self.total_epochs + 1) - - def record_epoch_start(self, epoch): - self.epoch_start_time = time.time() - self.epoch_iter = 0 - self.last_iter_time = time.time() - self.current_epoch = epoch - - def record_one_iteration(self): - current_time = time.time() - - # the last remaining batch is dropped (see data/__init__.py), - # so we can assume batch size is always opt.batchSize - self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize - self.last_iter_time = current_time - self.total_steps_so_far += self.opt.batchSize - self.epoch_iter += self.opt.batchSize - - def record_epoch_end(self): - current_time = time.time() - self.time_per_epoch = current_time - self.epoch_start_time - print( - "End of epoch %d / %d \t Time Taken: %d sec" - % (self.current_epoch, self.total_epochs, self.time_per_epoch) - ) - if self.current_epoch % self.opt.save_epoch_freq == 0: - np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d") - print("Saved current iteration count at %s." % self.iter_record_path) - - def record_current_iter(self): - np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d") - print("Saved current iteration count at %s." % self.iter_record_path) - - def needs_saving(self): - return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize - - def needs_printing(self): - return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize - - def needs_displaying(self): - return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize diff --git a/Face_Enhancement/util/util.py b/Face_Enhancement/util/util.py deleted file mode 100644 index c5d67f1..0000000 --- a/Face_Enhancement/util/util.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import re -import importlib -import torch -from argparse import Namespace -import numpy as np -from PIL import Image -import os -import argparse -import dill as pickle - - -def save_obj(obj, name): - with open(name, "wb") as f: - pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) - - -def load_obj(name): - with open(name, "rb") as f: - return pickle.load(f) - - -def copyconf(default_opt, **kwargs): - conf = argparse.Namespace(**vars(default_opt)) - for key in kwargs: - print(key, kwargs[key]) - setattr(conf, key, kwargs[key]) - return conf - - -# Converts a Tensor into a Numpy array -# |imtype|: the desired type of the converted numpy array -def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): - if isinstance(image_tensor, list): - image_numpy = [] - for i in range(len(image_tensor)): - image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) - return image_numpy - - if image_tensor.dim() == 4: - # transform each image in the batch - images_np = [] - for b in range(image_tensor.size(0)): - one_image = image_tensor[b] - one_image_np = tensor2im(one_image) - images_np.append(one_image_np.reshape(1, *one_image_np.shape)) - images_np = np.concatenate(images_np, axis=0) - - return images_np - - if image_tensor.dim() == 2: - image_tensor = image_tensor.unsqueeze(0) - image_numpy = image_tensor.detach().cpu().float().numpy() - if normalize: - image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 - else: - image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 - image_numpy = np.clip(image_numpy, 0, 255) - if image_numpy.shape[2] == 1: - image_numpy = image_numpy[:, :, 0] - return image_numpy.astype(imtype) - - -# Converts a one-hot tensor into a colorful label map -def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): - if label_tensor.dim() == 4: - # transform each image in the batch - images_np = [] - for b in range(label_tensor.size(0)): - one_image = label_tensor[b] - one_image_np = tensor2label(one_image, n_label, imtype) - images_np.append(one_image_np.reshape(1, *one_image_np.shape)) - images_np = np.concatenate(images_np, axis=0) - # if tile: - # images_tiled = tile_images(images_np) - # return images_tiled - # else: - # images_np = images_np[0] - # return images_np - return images_np - - if label_tensor.dim() == 1: - return np.zeros((64, 64, 3), dtype=np.uint8) - if n_label == 0: - return tensor2im(label_tensor, imtype) - label_tensor = label_tensor.cpu().float() - if label_tensor.size()[0] > 1: - label_tensor = label_tensor.max(0, keepdim=True)[1] - label_tensor = Colorize(n_label)(label_tensor) - label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) - result = label_numpy.astype(imtype) - return result - - -def save_image(image_numpy, image_path, create_dir=False): - if create_dir: - os.makedirs(os.path.dirname(image_path), exist_ok=True) - if len(image_numpy.shape) == 2: - image_numpy = np.expand_dims(image_numpy, axis=2) - if image_numpy.shape[2] == 1: - image_numpy = np.repeat(image_numpy, 3, 2) - image_pil = Image.fromarray(image_numpy) - - # save to png - image_pil.save(image_path.replace(".jpg", ".png")) - - -def mkdirs(paths): - if isinstance(paths, list) and not isinstance(paths, str): - for path in paths: - mkdir(path) - else: - mkdir(paths) - - -def mkdir(path): - if not os.path.exists(path): - os.makedirs(path) - - -def atoi(text): - return int(text) if text.isdigit() else text - - -def natural_keys(text): - """ - alist.sort(key=natural_keys) sorts in human order - 
http://nedbatchelder.com/blog/200712/human_sorting.html - (See Toothy's implementation in the comments) - """ - return [atoi(c) for c in re.split("(\d+)", text)] - - -def natural_sort(items): - items.sort(key=natural_keys) - - -def str2bool(v): - if v.lower() in ("yes", "true", "t", "y", "1"): - return True - elif v.lower() in ("no", "false", "f", "n", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - -def find_class_in_module(target_cls_name, module): - target_cls_name = target_cls_name.replace("_", "").lower() - - if module == 'models.networks.generator': - from ..models.networks import generator as clslib - elif module == 'models.networks.encoder': - from ..models.networks import encoder as clslib - else: - raise NotImplementedError(f'Module: {module}') - - cls = None - for name, clsobj in clslib.__dict__.items(): - if name.lower() == target_cls_name: - cls = clsobj - - if cls is None: - print( - "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" - % (module, target_cls_name) - ) - exit(0) - - return cls - - -def save_network(net, label, epoch, opt): - save_filename = "%s_net_%s.pth" % (epoch, label) - save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) - torch.save(net.cpu().state_dict(), save_path) - if len(opt.gpu_ids) and torch.cuda.is_available(): - net.cuda() - - -def load_network(net, label, epoch, opt): - save_filename = "%s_net_%s.pth" % (epoch, label) - save_dir = os.path.join(opt.checkpoints_dir, opt.name) - save_path = os.path.join(save_dir, save_filename) - if os.path.exists(save_path): - weights = torch.load(save_path) - net.load_state_dict(weights) - return net - - -############################################################################### -# Code from -# https://github.com/ycszen/pytorch-seg/blob/master/transform.py -# Modified so it complies with the Citscape label map colors -############################################################################### -def uint82bin(n, count=8): - """returns the binary of integer n, count refers to amount of bits""" - return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) - - -class Colorize(object): - def __init__(self, n=35): - self.cmap = labelcolormap(n) - self.cmap = torch.from_numpy(self.cmap[:n]) - - def __call__(self, gray_image): - size = gray_image.size() - color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) - - for label in range(0, len(self.cmap)): - mask = (label == gray_image[0]).cpu() - color_image[0][mask] = self.cmap[label][0] - color_image[1][mask] = self.cmap[label][1] - color_image[2][mask] = self.cmap[label][2] - - return color_image diff --git a/Face_Enhancement/util/visualizer.py b/Face_Enhancement/util/visualizer.py deleted file mode 100644 index f08e73f..0000000 --- a/Face_Enhancement/util/visualizer.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import os -import ntpath -import time -from . 
import util -import scipy.misc - -try: - from io import BytesIO # Python 3.x -except ImportError: - from StringIO import StringIO # Python 2.7 -import torchvision.utils as vutils -from tensorboardX import SummaryWriter -import torch -import numpy as np - - -class Visualizer: - def __init__(self, opt): - self.opt = opt - self.tf_log = opt.isTrain and opt.tf_log - - self.tensorboard_log = opt.tensorboard_log - - self.win_size = opt.display_winsize - self.name = opt.name - if self.tensorboard_log: - - if self.opt.isTrain: - self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs") - if not os.path.exists(self.log_dir): - os.makedirs(self.log_dir) - self.writer = SummaryWriter(log_dir=self.log_dir) - else: - # print("hi :)") - self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir) - if not os.path.exists(self.log_dir): - os.makedirs(self.log_dir) - - if opt.isTrain: - self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt") - with open(self.log_name, "a") as log_file: - now = time.strftime("%c") - log_file.write("================ Training Loss (%s) ================\n" % now) - - # |visuals|: dictionary of images to display or save - def display_current_results(self, visuals, epoch, step): - - all_tensor = [] - if self.tensorboard_log: - - for key, tensor in visuals.items(): - all_tensor.append((tensor.data.cpu() + 1) / 2) - - output = torch.cat(all_tensor, 0) - img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False) - - if self.opt.isTrain: - self.writer.add_image("Face_SPADE/training_samples", img_grid, step) - else: - vutils.save_image( - output, - os.path.join(self.log_dir, str(step) + ".png"), - nrow=self.opt.batchSize, - padding=0, - normalize=False, - ) - - # errors: dictionary of error labels and values - def plot_current_errors(self, errors, step): - if self.tf_log: - for tag, value in errors.items(): - value = value.mean().float() - summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) - self.writer.add_summary(summary, step) - - if self.tensorboard_log: - - self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step) - self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step) - self.writer.add_scalars( - "Loss/GAN", - { - "G": errors["GAN"].mean().float(), - "D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2, - }, - step, - ) - - # errors: same format as |errors| of plotCurrentErrors - def print_current_errors(self, epoch, i, errors, t): - message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t) - for k, v in errors.items(): - v = v.mean().float() - message += "%s: %.3f " % (k, v) - - print(message) - with open(self.log_name, "a") as log_file: - log_file.write("%s\n" % message) - - def convert_visuals_to_numpy(self, visuals): - for key, t in visuals.items(): - tile = self.opt.batchSize > 8 - if "input_label" == key: - t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy - else: - t = util.tensor2im(t, tile=tile) - visuals[key] = t - return visuals - - # save image to the disk - def save_images(self, webpage, visuals, image_path): - visuals = self.convert_visuals_to_numpy(visuals) - - image_dir = webpage.get_image_dir() - short_path = ntpath.basename(image_path[0]) - name = os.path.splitext(short_path)[0] - - webpage.add_header(name) - ims = [] - txts = [] - links = [] - - for label, image_numpy in visuals.items(): - image_name = os.path.join(label, "%s.png" % (name)) - save_path 
= os.path.join(image_dir, image_name) - util.save_image(image_numpy, save_path, create_dir=True) - - ims.append(image_name) - txts.append(label) - links.append(image_name) - webpage.add_images(ims, txts, links, width=self.win_size) diff --git a/Global/data/Create_Bigfile.py b/Global/data/Create_Bigfile.py deleted file mode 100644 index 2df6ef3..0000000 --- a/Global/data/Create_Bigfile.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import os -import struct -from PIL import Image - -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset(dir): - images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir)): - for fname in fnames: - if is_image_file(fname): - #print(fname) - path = os.path.join(root, fname) - images.append(path) - - return images - -### Modify these 3 lines in your own environment -indir="/home/ziyuwan/workspace/data/temp_old" -target_folders=['VOC','Real_L_old','Real_RGB_old'] -out_dir ="/home/ziyuwan/workspace/data/temp_old" -### - -if os.path.exists(out_dir) is False: - os.makedirs(out_dir) - -# -for target_folder in target_folders: - curr_indir = os.path.join(indir, target_folder) - curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder))) - image_lists = make_dataset(curr_indir) - image_lists.sort() - with open(curr_out_file, 'wb') as wfid: - # write total image number - wfid.write(struct.pack('i', len(image_lists))) - for i, img_path in enumerate(image_lists): - # write file name first - img_name = os.path.basename(img_path) - img_name_bytes = img_name.encode('utf-8') - wfid.write(struct.pack('i', len(img_name_bytes))) - wfid.write(img_name_bytes) - # - # # write image data in - with open(img_path, 'rb') as img_fid: - img_bytes = img_fid.read() - wfid.write(struct.pack('i', len(img_bytes))) - wfid.write(img_bytes) - - if i % 1000 == 0: - print('write %d images done' % i) \ No newline at end of file diff --git a/Global/data/Load_Bigfile.py b/Global/data/Load_Bigfile.py deleted file mode 100644 index b34f1ec..0000000 --- a/Global/data/Load_Bigfile.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import io -import os -import struct -from PIL import Image - -class BigFileMemoryLoader(object): - def __load_bigfile(self): - print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024)) - with open(self.file_path, 'rb') as fid: - self.img_num = struct.unpack('i', fid.read(4))[0] - self.img_names = [] - self.img_bytes = [] - print('find total %d images' % self.img_num) - for i in range(self.img_num): - img_name_len = struct.unpack('i', fid.read(4))[0] - img_name = fid.read(img_name_len).decode('utf-8') - self.img_names.append(img_name) - img_bytes_len = struct.unpack('i', fid.read(4))[0] - self.img_bytes.append(fid.read(img_bytes_len)) - if i % 5000 == 0: - print('load %d images done' % i) - print('load all %d images done' % self.img_num) - - def __init__(self, file_path): - super(BigFileMemoryLoader, self).__init__() - self.file_path = file_path - self.__load_bigfile() - - def __getitem__(self, index): - try: - img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB') - return self.img_names[index], img - except Exception: - print('Image read error for index %d: %s' % (index, self.img_names[index])) - return self.__getitem__((index+1)%self.img_num) - - - def __len__(self): - return self.img_num diff --git a/Global/data/__init__.py b/Global/data/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/Global/data/base_data_loader.py b/Global/data/base_data_loader.py deleted file mode 100644 index 05b1283..0000000 --- a/Global/data/base_data_loader.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -class BaseDataLoader(): - def __init__(self): - pass - - def initialize(self, opt): - self.opt = opt - pass - - def load_data(): - return None - - - diff --git a/Global/data/base_dataset.py b/Global/data/base_dataset.py deleted file mode 100644 index 5f0ac56..0000000 --- a/Global/data/base_dataset.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch.utils.data as data -from PIL import Image -import torchvision.transforms as transforms -import numpy as np -import random - -class BaseDataset(data.Dataset): - def __init__(self): - super(BaseDataset, self).__init__() - - def name(self): - return 'BaseDataset' - - def initialize(self, opt): - pass - -def get_params(opt, size): - w, h = size - new_h = h - new_w = w - if opt.resize_or_crop == 'resize_and_crop': - new_h = new_w = opt.loadSize - - if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256 - - if w 0.5 - return {'crop_pos': (x, y), 'flip': flip} - -def get_transform(opt, params, method=Image.BICUBIC, normalize=True): - transform_list = [] - if 'resize' in opt.resize_or_crop: - osize = [opt.loadSize, opt.loadSize] - transform_list.append(transforms.Scale(osize, method)) - elif 'scale_width' in opt.resize_or_crop: - # transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it. 
- transform_list.append(transforms.Scale(256,method)) - - if 'crop' in opt.resize_or_crop: - if opt.isTrain: - transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) - else: - if opt.test_random_crop: - transform_list.append(transforms.RandomCrop(opt.fineSize)) - else: - transform_list.append(transforms.CenterCrop(opt.fineSize)) - - ## when testing, for ablation study, choose center_crop directly. - - - - if opt.resize_or_crop == 'none': - base = float(2 ** opt.n_downsample_global) - if opt.netG == 'local': - base *= (2 ** opt.n_local_enhancers) - transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) - - if opt.isTrain and not opt.no_flip: - transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) - - transform_list += [transforms.ToTensor()] - - if normalize: - transform_list += [transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] - return transforms.Compose(transform_list) - -def normalize(): - return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - -def __make_power_2(img, base, method=Image.BICUBIC): - ow, oh = img.size - h = int(round(oh / base) * base) - w = int(round(ow / base) * base) - if (h == oh) and (w == ow): - return img - return img.resize((w, h), method) - -def __scale_width(img, target_width, method=Image.BICUBIC): - ow, oh = img.size - if (ow == target_width): - return img - w = target_width - h = int(target_width * oh / ow) - return img.resize((w, h), method) - -def __crop(img, pos, size): - ow, oh = img.size - x1, y1 = pos - tw = th = size - if (ow > tw or oh > th): - return img.crop((x1, y1, x1 + tw, y1 + th)) - return img - -def __flip(img, flip): - if flip: - return img.transpose(Image.FLIP_LEFT_RIGHT) - return img diff --git a/Global/data/custom_dataset_data_loader.py b/Global/data/custom_dataset_data_loader.py deleted file mode 100644 index 04cc032..0000000 --- a/Global/data/custom_dataset_data_loader.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch.utils.data -import random -from data.base_data_loader import BaseDataLoader -from data import online_dataset_for_old_photos as dts_ray_bigfile - - -def CreateDataset(opt): - dataset = None - if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B': - dataset = dts_ray_bigfile.UnPairOldPhotos_SR() - if opt.training_dataset=='mapping': - if opt.random_hole: - dataset = dts_ray_bigfile.PairOldPhotos_with_hole() - else: - dataset = dts_ray_bigfile.PairOldPhotos() - print("dataset [%s] was created" % (dataset.name())) - dataset.initialize(opt) - return dataset - -class CustomDatasetDataLoader(BaseDataLoader): - def name(self): - return 'CustomDatasetDataLoader' - - def initialize(self, opt): - BaseDataLoader.initialize(self, opt) - self.dataset = CreateDataset(opt) - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=opt.batchSize, - shuffle=not opt.serial_batches, - num_workers=int(opt.nThreads), - drop_last=True) - - def load_data(self): - return self.dataloader - - def __len__(self): - return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/Global/data/data_loader.py b/Global/data/data_loader.py deleted file mode 100644 index 02ccaed..0000000 --- a/Global/data/data_loader.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -def CreateDataLoader(opt): - from data.custom_dataset_data_loader import CustomDatasetDataLoader - data_loader = CustomDatasetDataLoader() - print(data_loader.name()) - data_loader.initialize(opt) - return data_loader diff --git a/Global/data/image_folder.py b/Global/data/image_folder.py deleted file mode 100644 index 8b1b956..0000000 --- a/Global/data/image_folder.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch.utils.data as data -from PIL import Image -import os - -IMG_EXTENSIONS = [ - '.jpg', '.JPG', '.jpeg', '.JPEG', - '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' -] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def make_dataset(dir): - images = [] - assert os.path.isdir(dir), '%s is not a valid directory' % dir - - for root, _, fnames in sorted(os.walk(dir)): - for fname in fnames: - if is_image_file(fname): - path = os.path.join(root, fname) - images.append(path) - - return images - - -def default_loader(path): - return Image.open(path).convert('RGB') - - -class ImageFolder(data.Dataset): - - def __init__(self, root, transform=None, return_paths=False, - loader=default_loader): - imgs = make_dataset(root) - if len(imgs) == 0: - raise(RuntimeError("Found 0 images in: " + root + "\n" - "Supported image extensions are: " + - ",".join(IMG_EXTENSIONS))) - - self.root = root - self.imgs = imgs - self.transform = transform - self.return_paths = return_paths - self.loader = loader - - def __getitem__(self, index): - path = self.imgs[index] - img = self.loader(path) - if self.transform is not None: - img = self.transform(img) - if self.return_paths: - return img, path - else: - return img - - def __len__(self): - return len(self.imgs) diff --git a/Global/detection.py b/Global/detection.py deleted file mode 100644 index 8e2bbf3..0000000 --- a/Global/detection.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import torchvision.transforms as transforms -from PIL import Image, ImageFile -import torch.nn.functional as F -import torchvision as tv -import numpy as np -import warnings -import argparse -import torch -import gc -import os - -from .detection_models import networks - - -tensor2image = transforms.ToPILImage() - -warnings.filterwarnings("ignore", category=UserWarning) - -ImageFile.LOAD_TRUNCATED_IMAGES = True - - -def data_transforms(img, full_size, method=Image.BICUBIC): - if full_size == "full_size": - ow, oh = img.size - h = int(round(oh / 16) * 16) - w = int(round(ow / 16) * 16) - if (h == oh) and (w == ow): - return img - return img.resize((w, h), method) - - elif full_size == "scale_256": - ow, oh = img.size - pw, ph = ow, oh - if ow < oh: - ow = 256 - oh = ph / pw * 256 - else: - oh = 256 - ow = pw / ph * 256 - - h = int(round(oh / 16) * 16) - w = int(round(ow / 16) * 16) - if (h == ph) and (w == pw): - return img - return img.resize((w, h), method) - - -def scale_tensor(img_tensor, default_scale=256): - _, _, w, h = img_tensor.shape - if w < h: - ow = default_scale - oh = h / w * default_scale - else: - oh = default_scale - ow = w / h * default_scale - - oh = int(round(oh / 16) * 16) - ow = int(round(ow / 16) * 16) - - return F.interpolate(img_tensor, [ow, oh], mode="bilinear") - - -def blend_mask(img, mask): - - np_img = np.array(img).astype("float") - - return Image.fromarray( - (np_img * (1 - mask) + mask * 255.0).astype("uint8") - ).convert("RGB") - - -def main(config: argparse.Namespace, 
input_image: Image): - - model = networks.UNet( - in_channels=1, - out_channels=1, - depth=4, - conv_num=2, - wf=6, - padding=True, - batch_norm=True, - up_mode="upsample", - with_tanh=False, - sync_bn=True, - antialiasing=True, - ) - - ## load model - checkpoint_path = os.path.join( - os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt" - ) - checkpoint = torch.load(checkpoint_path, map_location="cpu") - model.load_state_dict(checkpoint["model_state"]) - print("model weights loaded") - - if torch.cuda.is_available() and config.GPU >= 0: - model.to(config.GPU) - else: - model.cpu() - - model.eval() - - print("processing...") - - transformed_image_PIL = data_transforms(input_image, config.input_size) - input_image = transformed_image_PIL.convert("L") - input_image = tv.transforms.ToTensor()(input_image) - input_image = tv.transforms.Normalize([0.5], [0.5])(input_image) - input_image = torch.unsqueeze(input_image, 0) - - _, _, ow, oh = input_image.shape - scratch_image_scale = scale_tensor(input_image) - - if torch.cuda.is_available() and config.GPU >= 0: - scratch_image_scale = scratch_image_scale.to(config.GPU) - else: - scratch_image_scale = scratch_image_scale.cpu() - - with torch.no_grad(): - P = torch.sigmoid(model(scratch_image_scale)) - - P = P.data.cpu() - P = F.interpolate(P, [ow, oh], mode="nearest") - - scratch_mask = torch.clamp((P >= 0.4).float(), 0.0, 1.0) * 255 - - gc.collect() - torch.cuda.empty_cache() - return tensor2image(scratch_mask[0].byte()), transformed_image_PIL - - -def global_detection( - input_image: Image, - gpu: int, - input_size: str, -) -> Image: - - config = argparse.Namespace() - config.GPU = gpu - config.input_size = input_size - - return main(config, input_image) diff --git a/Global/detection_models/__init__.py b/Global/detection_models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/Global/detection_models/antialiasing.py b/Global/detection_models/antialiasing.py deleted file mode 100644 index 78da8eb..0000000 --- a/Global/detection_models/antialiasing.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import torch -import torch.nn.parallel -import numpy as np -import torch.nn as nn -import torch.nn.functional as F - - -class Downsample(nn.Module): - # https://github.com/adobe/antialiased-cnns - - def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0): - super(Downsample, self).__init__() - self.filt_size = filt_size - self.pad_off = pad_off - self.pad_sizes = [ - int(1.0 * (filt_size - 1) / 2), - int(np.ceil(1.0 * (filt_size - 1) / 2)), - int(1.0 * (filt_size - 1) / 2), - int(np.ceil(1.0 * (filt_size - 1) / 2)), - ] - self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] - self.stride = stride - self.off = int((self.stride - 1) / 2.0) - self.channels = channels - - # print('Filter size [%i]'%filt_size) - if self.filt_size == 1: - a = np.array([1.0,]) - elif self.filt_size == 2: - a = np.array([1.0, 1.0]) - elif self.filt_size == 3: - a = np.array([1.0, 2.0, 1.0]) - elif self.filt_size == 4: - a = np.array([1.0, 3.0, 3.0, 1.0]) - elif self.filt_size == 5: - a = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) - elif self.filt_size == 6: - a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) - elif self.filt_size == 7: - a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) - - filt = torch.Tensor(a[:, None] * a[None, :]) - filt = filt / torch.sum(filt) - self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) - - self.pad = get_pad_layer(pad_type)(self.pad_sizes) - - def forward(self, inp): - if self.filt_size == 1: - if self.pad_off == 0: - return inp[:, :, :: self.stride, :: self.stride] - else: - return self.pad(inp)[:, :, :: self.stride, :: self.stride] - else: - return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) - - -def get_pad_layer(pad_type): - if pad_type in ["refl", "reflect"]: - PadLayer = nn.ReflectionPad2d - elif pad_type in ["repl", "replicate"]: - PadLayer = nn.ReplicationPad2d - elif pad_type == "zero": - PadLayer = nn.ZeroPad2d - else: - print("Pad type [%s] not recognized" % pad_type) - return PadLayer diff --git a/Global/detection_models/networks.py b/Global/detection_models/networks.py deleted file mode 100644 index c3bff4c..0000000 --- a/Global/detection_models/networks.py +++ /dev/null @@ -1,332 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from .sync_batchnorm import DataParallelWithCallback -from .antialiasing import Downsample - - -class UNet(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - depth=5, - conv_num=2, - wf=6, - padding=True, - batch_norm=True, - up_mode="upsample", - with_tanh=False, - sync_bn=True, - antialiasing=True, - ): - """ - Implementation of - U-Net: Convolutional Networks for Biomedical Image Segmentation - (Ronneberger et al., 2015) - https://arxiv.org/abs/1505.04597 - Using the default arguments will yield the exact version used - in the original paper - Args: - in_channels (int): number of input channels - out_channels (int): number of output channels - depth (int): depth of the network - wf (int): number of filters in the first layer is 2**wf - padding (bool): if True, apply padding such that the input shape - is the same as the output. - This may introduce artifacts - batch_norm (bool): Use BatchNorm after layers with an - activation function - up_mode (str): one of 'upconv' or 'upsample'. - 'upconv' will use transposed convolutions for - learned upsampling. - 'upsample' will use bilinear upsampling. 
- """ - super().__init__() - assert up_mode in ("upconv", "upsample") - self.padding = padding - self.depth = depth - 1 - prev_channels = in_channels - - self.first = nn.Sequential( - *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)] - ) - prev_channels = 2 ** wf - - self.down_path = nn.ModuleList() - self.down_sample = nn.ModuleList() - for i in range(depth): - if antialiasing and depth > 0: - self.down_sample.append( - nn.Sequential( - *[ - nn.ReflectionPad2d(1), - nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0), - nn.BatchNorm2d(prev_channels), - nn.LeakyReLU(0.2, True), - Downsample(channels=prev_channels, stride=2), - ] - ) - ) - else: - self.down_sample.append( - nn.Sequential( - *[ - nn.ReflectionPad2d(1), - nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0), - nn.BatchNorm2d(prev_channels), - nn.LeakyReLU(0.2, True), - ] - ) - ) - self.down_path.append( - UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm) - ) - prev_channels = 2 ** (wf + i + 1) - - self.up_path = nn.ModuleList() - for i in reversed(range(depth)): - self.up_path.append( - UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm) - ) - prev_channels = 2 ** (wf + i) - - if with_tanh: - self.last = nn.Sequential( - *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()] - ) - else: - self.last = nn.Sequential( - *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)] - ) - - if sync_bn: - self = DataParallelWithCallback(self) - - def forward(self, x): - x = self.first(x) - - blocks = [] - for i, down_block in enumerate(self.down_path): - blocks.append(x) - x = self.down_sample[i](x) - x = down_block(x) - - for i, up in enumerate(self.up_path): - x = up(x, blocks[-i - 1]) - - return self.last(x) - - -class UNetConvBlock(nn.Module): - def __init__(self, conv_num, in_size, out_size, padding, batch_norm): - super(UNetConvBlock, self).__init__() - block = [] - - for _ in range(conv_num): - block.append(nn.ReflectionPad2d(padding=int(padding))) - block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0)) - if batch_norm: - block.append(nn.BatchNorm2d(out_size)) - block.append(nn.LeakyReLU(0.2, True)) - in_size = out_size - - self.block = nn.Sequential(*block) - - def forward(self, x): - out = self.block(x) - return out - - -class UNetUpBlock(nn.Module): - def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm): - super(UNetUpBlock, self).__init__() - if up_mode == "upconv": - self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) - elif up_mode == "upsample": - self.up = nn.Sequential( - nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False), - nn.ReflectionPad2d(1), - nn.Conv2d(in_size, out_size, kernel_size=3, padding=0), - ) - - self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm) - - def center_crop(self, layer, target_size): - _, _, layer_height, layer_width = layer.size() - diff_y = (layer_height - target_size[0]) // 2 - diff_x = (layer_width - target_size[1]) // 2 - return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])] - - def forward(self, x, bridge): - up = self.up(x) - crop1 = self.center_crop(bridge, up.shape[2:]) - out = torch.cat([up, crop1], 1) - out = self.conv_block(out) - - return out - - -class UnetGenerator(nn.Module): - """Create a Unet-based generator""" - - def 
__init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False): - """Construct a Unet generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, - image of size 128x128 will become of size 1x1 # at the bottleneck - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - We construct the U-Net from the innermost layer to the outermost layer. - It is a recursive process. - """ - super().__init__() - if norm_type == "BN": - norm_layer = nn.BatchNorm2d - elif norm_type == "IN": - norm_layer = nn.InstanceNorm2d - else: - raise NameError("Unknown norm layer") - - # construct unet structure - unet_block = UnetSkipConnectionBlock( - ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True - ) # add the innermost layer - for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters - unet_block = UnetSkipConnectionBlock( - ngf * 8, - ngf * 8, - input_nc=None, - submodule=unet_block, - norm_layer=norm_layer, - use_dropout=use_dropout, - ) - # gradually reduce the number of filters from ngf * 8 to ngf - unet_block = UnetSkipConnectionBlock( - ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer - ) - unet_block = UnetSkipConnectionBlock( - ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer - ) - unet_block = UnetSkipConnectionBlock( - ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer - ) - self.model = UnetSkipConnectionBlock( - output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer - ) # add the outermost layer - - def forward(self, input): - return self.model(input) - - -class UnetSkipConnectionBlock(nn.Module): - """Defines the Unet submodule with skip connection. - - -------------------identity---------------------- - |-- downsampling -- |submodule| -- upsampling --| - """ - - def __init__( - self, - outer_nc, - inner_nc, - input_nc=None, - submodule=None, - outermost=False, - innermost=False, - norm_layer=nn.BatchNorm2d, - use_dropout=False, - ): - """Construct a Unet submodule with skip connections. - Parameters: - outer_nc (int) -- the number of filters in the outer conv layer - inner_nc (int) -- the number of filters in the inner conv layer - input_nc (int) -- the number of channels in input images/features - submodule (UnetSkipConnectionBlock) -- previously defined submodules - outermost (bool) -- if this module is the outermost module - innermost (bool) -- if this module is the innermost module - norm_layer -- normalization layer - user_dropout (bool) -- if use dropout layers. 
- """ - super().__init__() - self.outermost = outermost - use_bias = norm_layer == nn.InstanceNorm2d - if input_nc is None: - input_nc = outer_nc - downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) - downrelu = nn.LeakyReLU(0.2, True) - downnorm = norm_layer(inner_nc) - uprelu = nn.LeakyReLU(0.2, True) - upnorm = norm_layer(outer_nc) - - if outermost: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) - down = [downconv] - up = [uprelu, upconv, nn.Tanh()] - model = down + [submodule] + up - elif innermost: - upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) - down = [downrelu, downconv] - up = [uprelu, upconv, upnorm] - model = down + up - else: - upconv = nn.ConvTranspose2d( - inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias - ) - down = [downrelu, downconv, downnorm] - up = [uprelu, upconv, upnorm] - - if use_dropout: - model = down + [submodule] + up + [nn.Dropout(0.5)] - else: - model = down + [submodule] + up - - self.model = nn.Sequential(*model) - - def forward(self, x): - if self.outermost: - return self.model(x) - else: # add skip connections - return torch.cat([x, self.model(x)], 1) - - -# ============================================ -# Network testing -# ============================================ -if __name__ == "__main__": - from torchsummary import summary - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - model = UNet_two_decoders( - in_channels=3, - out_channels1=3, - out_channels2=1, - depth=4, - conv_num=1, - wf=6, - padding=True, - batch_norm=True, - up_mode="upsample", - with_tanh=False, - ) - model.to(device) - - model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False) - model_pix2pix.to(device) - - print("customized unet:") - summary(model, (3, 256, 256)) - - print("cyclegan unet:") - summary(model_pix2pix, (3, 256, 256)) - - x = torch.zeros(1, 3, 256, 256).requires_grad_(True).cuda() - g = make_dot(model(x)) - g.render("models/Digraph.gv", view=False) - diff --git a/Global/detection_models/sync_batchnorm/__init__.py b/Global/detection_models/sync_batchnorm/__init__.py deleted file mode 100644 index 6d9b36c..0000000 --- a/Global/detection_models/sync_batchnorm/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -# File : __init__.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -from .batchnorm import set_sbn_eps_mode -from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d -from .batchnorm import patch_sync_batchnorm, convert_model -from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/Global/detection_models/sync_batchnorm/batchnorm.py b/Global/detection_models/sync_batchnorm/batchnorm.py deleted file mode 100644 index bf8d7a7..0000000 --- a/Global/detection_models/sync_batchnorm/batchnorm.py +++ /dev/null @@ -1,412 +0,0 @@ -# -*- coding: utf-8 -*- -# File : batchnorm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import collections -import contextlib - -import torch -import torch.nn.functional as F - -from torch.nn.modules.batchnorm import _BatchNorm - -try: - from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast -except ImportError: - ReduceAddCoalesced = Broadcast = None - -try: - from jactorch.parallel.comm import SyncMaster - from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback -except ImportError: - from .comm import SyncMaster - from .replicate import DataParallelWithCallback - -__all__ = [ - 'set_sbn_eps_mode', - 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', - 'patch_sync_batchnorm', 'convert_model' -] - - -SBN_EPS_MODE = 'clamp' - - -def set_sbn_eps_mode(mode): - global SBN_EPS_MODE - assert mode in ('clamp', 'plus') - SBN_EPS_MODE = mode - - -def _sum_ft(tensor): - """sum over the first and last dimention""" - return tensor.sum(dim=0).sum(dim=-1) - - -def _unsqueeze_ft(tensor): - """add new dimensions at the front and the tail""" - return tensor.unsqueeze(0).unsqueeze(-1) - - -_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) -_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) - - -class _SynchronizedBatchNorm(_BatchNorm): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): - assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' - - super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, - track_running_stats=track_running_stats) - - if not self.track_running_stats: - import warnings - warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') - - self._sync_master = SyncMaster(self._data_parallel_master) - - self._is_parallel = False - self._parallel_id = None - self._slave_pipe = None - - def forward(self, input): - # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. - if not (self._is_parallel and self.training): - return F.batch_norm( - input, self.running_mean, self.running_var, self.weight, self.bias, - self.training, self.momentum, self.eps) - - # Resize the input to (B, C, -1). - input_shape = input.size() - assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) - input = input.view(input.size(0), self.num_features, -1) - - # Compute the sum and square-sum. - sum_size = input.size(0) * input.size(2) - input_sum = _sum_ft(input) - input_ssum = _sum_ft(input ** 2) - - # Reduce-and-broadcast the statistics. - if self._parallel_id == 0: - mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) - else: - mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) - - # Compute the output. - if self.affine: - # MJY:: Fuse the multiplication for speed. - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) - else: - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) - - # Reshape it. - return output.view(input_shape) - - def __data_parallel_replicate__(self, ctx, copy_id): - self._is_parallel = True - self._parallel_id = copy_id - - # parallel_id == 0 means master device. 
- if self._parallel_id == 0: - ctx.sync_master = self._sync_master - else: - self._slave_pipe = ctx.sync_master.register_slave(copy_id) - - def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" - - # Always using same "device order" makes the ReduceAdd operation faster. - # Thanks to:: Tete Xiao (http://tetexiao.com/) - intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) - - to_reduce = [i[1][:2] for i in intermediates] - to_reduce = [j for i in to_reduce for j in i] # flatten - target_gpus = [i[1].sum.get_device() for i in intermediates] - - sum_size = sum([i[1].sum_size for i in intermediates]) - sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) - mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) - - broadcasted = Broadcast.apply(target_gpus, mean, inv_std) - - outputs = [] - for i, rec in enumerate(intermediates): - outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) - - return outputs - - def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with sum and square-sum. This method - also maintains the moving average on the master device.""" - assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' - mean = sum_ / size - sumvar = ssum - sum_ * mean - unbias_var = sumvar / (size - 1) - bias_var = sumvar / size - - if hasattr(torch, 'no_grad'): - with torch.no_grad(): - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - else: - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data - - if SBN_EPS_MODE == 'clamp': - return mean, bias_var.clamp(self.eps) ** -0.5 - elif SBN_EPS_MODE == 'plus': - return mean, (bias_var + self.eps) ** -0.5 - else: - raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) - - -class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): - r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a - mini-batch. - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm1d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. 
- - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm - - Args: - num_features: num_features from an expected input of size - `batch_size x num_features [x width]` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` - - Shape:: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 2 and input.dim() != 3: - raise ValueError('expected 2D or 3D input (got {}D input)' - .format(input.dim())) - - -class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch - of 3d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm2d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. 
Default: ``True`` - - Shape:: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 4: - raise ValueError('expected 4D input (got {}D input)' - .format(input.dim())) - - -class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch - of 4d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm3d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device using - the statistics only on that device, which accelerated the computation and - is also easy to implement, but the statistics might be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm - or Spatio-temporal BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x depth x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. 
Default: ``True`` - - Shape:: - - Input: :math:`(N, C, D, H, W)` - - Output: :math:`(N, C, D, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm3d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 5: - raise ValueError('expected 5D input (got {}D input)' - .format(input.dim())) - - -@contextlib.contextmanager -def patch_sync_batchnorm(): - import torch.nn as nn - - backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d - - nn.BatchNorm1d = SynchronizedBatchNorm1d - nn.BatchNorm2d = SynchronizedBatchNorm2d - nn.BatchNorm3d = SynchronizedBatchNorm3d - - yield - - nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup - - -def convert_model(module): - """Traverse the input module and its child recursively - and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d - to SynchronizedBatchNorm*N*d - - Args: - module: the input module needs to be convert to SyncBN model - - Examples: - >>> import torch.nn as nn - >>> import torchvision - >>> # m is a standard pytorch model - >>> m = torchvision.models.resnet18(True) - >>> m = nn.DataParallel(m) - >>> # after convert, m is using SyncBN - >>> m = convert_model(m) - """ - if isinstance(module, torch.nn.DataParallel): - mod = module.module - mod = convert_model(mod) - mod = DataParallelWithCallback(mod, device_ids=module.device_ids) - return mod - - mod = module - for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, - torch.nn.modules.batchnorm.BatchNorm2d, - torch.nn.modules.batchnorm.BatchNorm3d], - [SynchronizedBatchNorm1d, - SynchronizedBatchNorm2d, - SynchronizedBatchNorm3d]): - if isinstance(module, pth_module): - mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) - mod.running_mean = module.running_mean - mod.running_var = module.running_var - if module.affine: - mod.weight.data = module.weight.data.clone().detach() - mod.bias.data = module.bias.data.clone().detach() - - for name, child in module.named_children(): - mod.add_module(name, convert_model(child)) - - return mod diff --git a/Global/detection_models/sync_batchnorm/batchnorm_reimpl.py b/Global/detection_models/sync_batchnorm/batchnorm_reimpl.py deleted file mode 100644 index 18145c3..0000000 --- a/Global/detection_models/sync_batchnorm/batchnorm_reimpl.py +++ /dev/null @@ -1,74 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# File : batchnorm_reimpl.py -# Author : acgtyrant -# Date : 11/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import torch -import torch.nn as nn -import torch.nn.init as init - -__all__ = ['BatchNorm2dReimpl'] - - -class BatchNorm2dReimpl(nn.Module): - """ - A re-implementation of batch normalization, used for testing the numerical - stability. 
- - Author: acgtyrant - See also: - https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 - """ - def __init__(self, num_features, eps=1e-5, momentum=0.1): - super().__init__() - - self.num_features = num_features - self.eps = eps - self.momentum = momentum - self.weight = nn.Parameter(torch.empty(num_features)) - self.bias = nn.Parameter(torch.empty(num_features)) - self.register_buffer('running_mean', torch.zeros(num_features)) - self.register_buffer('running_var', torch.ones(num_features)) - self.reset_parameters() - - def reset_running_stats(self): - self.running_mean.zero_() - self.running_var.fill_(1) - - def reset_parameters(self): - self.reset_running_stats() - init.uniform_(self.weight) - init.zeros_(self.bias) - - def forward(self, input_): - batchsize, channels, height, width = input_.size() - numel = batchsize * height * width - input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) - sum_ = input_.sum(1) - sum_of_square = input_.pow(2).sum(1) - mean = sum_ / numel - sumvar = sum_of_square - sum_ * mean - - self.running_mean = ( - (1 - self.momentum) * self.running_mean - + self.momentum * mean.detach() - ) - unbias_var = sumvar / (numel - 1) - self.running_var = ( - (1 - self.momentum) * self.running_var - + self.momentum * unbias_var.detach() - ) - - bias_var = sumvar / numel - inv_std = 1 / (bias_var + self.eps).pow(0.5) - output = ( - (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * - self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) - - return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() - diff --git a/Global/detection_models/sync_batchnorm/comm.py b/Global/detection_models/sync_batchnorm/comm.py deleted file mode 100644 index 922f8c4..0000000 --- a/Global/detection_models/sync_batchnorm/comm.py +++ /dev/null @@ -1,137 +0,0 @@ -# -*- coding: utf-8 -*- -# File : comm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import queue -import collections -import threading - -__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] - - -class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" - - def __init__(self): - self._result = None - self._lock = threading.Lock() - self._cond = threading.Condition(self._lock) - - def put(self, result): - with self._lock: - assert self._result is None, 'Previous result has\'t been fetched.' - self._result = result - self._cond.notify() - - def get(self): - with self._lock: - if self._result is None: - self._cond.wait() - - res = self._result - self._result = None - return res - - -_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) -_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) - - -class SlavePipe(_SlavePipeBase): - """Pipe for master-slave communication.""" - - def run_slave(self, msg): - self.queue.put((self.identifier, msg)) - ret = self.result.get() - self.queue.put(True) - return ret - - -class SyncMaster(object): - """An abstract `SyncMaster` object. - - - During the replication, as the data parallel will trigger an callback of each module, all slave devices should - call `register(id)` and obtain an `SlavePipe` to communicate with the master. 
- - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, - and passed to a registered callback. - - After receiving the messages, the master device should gather the information and determine to message passed - back to each slave devices. - """ - - def __init__(self, master_callback): - """ - - Args: - master_callback: a callback to be invoked after having collected messages from slave devices. - """ - self._master_callback = master_callback - self._queue = queue.Queue() - self._registry = collections.OrderedDict() - self._activated = False - - def __getstate__(self): - return {'master_callback': self._master_callback} - - def __setstate__(self, state): - self.__init__(state['master_callback']) - - def register_slave(self, identifier): - """ - Register an slave device. - - Args: - identifier: an identifier, usually is the device id. - - Returns: a `SlavePipe` object which can be used to communicate with the master device. - - """ - if self._activated: - assert self._queue.empty(), 'Queue is not clean before next initialization.' - self._activated = False - self._registry.clear() - future = FutureResult() - self._registry[identifier] = _MasterRegistry(future) - return SlavePipe(identifier, self._queue, future) - - def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. - The messages were first collected from each devices (including the master device), and then - an callback will be invoked to compute the message to be sent back to each devices - (including the master device). - - Args: - master_msg: the message that the master want to send to itself. This will be placed as the first - message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. - - Returns: the message to be sent back to the master device. - - """ - self._activated = True - - intermediates = [(0, master_msg)] - for i in range(self.nr_slaves): - intermediates.append(self._queue.get()) - - results = self._master_callback(intermediates) - assert results[0][0] == 0, 'The first result should belongs to the master.' - - for i, res in results: - if i == 0: - continue - self._registry[i].result.put(res) - - for i in range(self.nr_slaves): - assert self._queue.get() is True - - return results[0][1] - - @property - def nr_slaves(self): - return len(self._registry) diff --git a/Global/detection_models/sync_batchnorm/replicate.py b/Global/detection_models/sync_batchnorm/replicate.py deleted file mode 100644 index b71c7b8..0000000 --- a/Global/detection_models/sync_batchnorm/replicate.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# File : replicate.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import functools - -from torch.nn.parallel.data_parallel import DataParallel - -__all__ = [ - 'CallbackContext', - 'execute_replication_callbacks', - 'DataParallelWithCallback', - 'patch_replication_callback' -] - - -class CallbackContext(object): - pass - - -def execute_replication_callbacks(modules): - """ - Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
- - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Note that, as all modules are isomorphism, we assign each sub-module with a context - (shared among multiple copies of this module on different devices). - Through this context, different copies can share some information. - - We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback - of any slave copies. - """ - master_copy = modules[0] - nr_modules = len(list(master_copy.modules())) - ctxs = [CallbackContext() for _ in range(nr_modules)] - - for i, module in enumerate(modules): - for j, m in enumerate(module.modules()): - if hasattr(m, '__data_parallel_replicate__'): - m.__data_parallel_replicate__(ctxs[j], i) - - -class DataParallelWithCallback(DataParallel): - """ - Data Parallel with a replication callback. - - An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by - original `replicate` function. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - # sync_bn.__data_parallel_replicate__ will be invoked. - """ - - def replicate(self, module, device_ids): - modules = super(DataParallelWithCallback, self).replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - -def patch_replication_callback(data_parallel): - """ - Monkey-patch an existing `DataParallel` object. Add the replication callback. - Useful when you have customized `DataParallel` implementation. - - Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) - > patch_replication_callback(sync_bn) - # this is equivalent to - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - """ - - assert isinstance(data_parallel, DataParallel) - - old_replicate = data_parallel.replicate - - @functools.wraps(old_replicate) - def new_replicate(module, device_ids): - modules = old_replicate(module, device_ids) - execute_replication_callbacks(modules) - return modules - - data_parallel.replicate = new_replicate diff --git a/Global/detection_models/sync_batchnorm/unittest.py b/Global/detection_models/sync_batchnorm/unittest.py deleted file mode 100644 index 998223a..0000000 --- a/Global/detection_models/sync_batchnorm/unittest.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# File : unittest.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import unittest -import torch - - -class TorchTestCase(unittest.TestCase): - def assertTensorClose(self, x, y): - adiff = float((x - y).abs().max()) - if (y == 0).all(): - rdiff = 'NaN' - else: - rdiff = float((adiff / y).abs().max()) - - message = ( - 'Tensor close check failed\n' - 'adiff={}\n' - 'rdiff={}\n' - ).format(adiff, rdiff) - self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) - diff --git a/Global/detection_util/util.py b/Global/detection_util/util.py deleted file mode 100644 index 4f8c34d..0000000 --- a/Global/detection_util/util.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
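For reference, a minimal usage sketch of the sync_batchnorm helpers deleted above. It assumes a host with at least two CUDA devices and an import path mirroring the deleted Global/detection_models/sync_batchnorm layout; convert_model swaps every BatchNorm*d for its synchronized counterpart and rewraps the model in DataParallelWithCallback so that the __data_parallel_replicate__ hook actually fires on each replica.

    import torch
    import torch.nn as nn
    # Import path is illustrative; it mirrors the deleted package layout.
    from detection_models.sync_batchnorm import convert_model

    # Toy model whose BatchNorm2d layers get replaced by SynchronizedBatchNorm2d.
    model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
    model = nn.DataParallel(model.cuda(), device_ids=[0, 1])
    model = convert_model(model)  # returns a DataParallelWithCallback wrapping SyncBN layers
    out = model(torch.randn(8, 3, 32, 32).cuda())

The patch_sync_batchnorm() context manager shown earlier is the alternative for code you cannot rewrap: any nn.BatchNorm1d/2d/3d constructed inside the `with` block is the synchronized variant.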
-# Licensed under the MIT License. - -import os -import sys -import time -import shutil -import platform -import numpy as np -from datetime import datetime - -import torch -import torchvision as tv -import torch.backends.cudnn as cudnn - -# from torch.utils.tensorboard import SummaryWriter - -import yaml -import matplotlib.pyplot as plt -from easydict import EasyDict as edict -import torchvision.utils as vutils - - -##### option parsing ###### -def print_options(config_dict): - print("------------ Options -------------") - for k, v in sorted(config_dict.items()): - print("%s: %s" % (str(k), str(v))) - print("-------------- End ----------------") - - -def save_options(config_dict): - from time import gmtime, strftime - - file_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"]) - mkdir_if_not(file_dir) - file_name = os.path.join(file_dir, "opt.txt") - with open(file_name, "wt") as opt_file: - opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n") - opt_file.write("------------ Options -------------\n") - for k, v in sorted(config_dict.items()): - opt_file.write("%s: %s\n" % (str(k), str(v))) - opt_file.write("-------------- End ----------------\n") - - -def config_parse(config_file, options, save=True): - with open(config_file, "r") as stream: - config_dict = yaml.safe_load(stream) - config = edict(config_dict) - - for option_key, option_value in vars(options).items(): - config_dict[option_key] = option_value - config[option_key] = option_value - - if config.debug_mode: - config_dict["num_workers"] = 0 - config.num_workers = 0 - config.batch_size = 2 - if isinstance(config.gpu_ids, str): - config.gpu_ids = [int(x) for x in config.gpu_ids.split(",")][0] - - print_options(config_dict) - if save: - save_options(config_dict) - - return config - - -###### utility ###### -def to_np(x): - return x.cpu().numpy() - - -def prepare_device(use_gpu, gpu_ids): - if torch.cuda.is_available() and use_gpu: - cudnn.benchmark = True - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - if isinstance(gpu_ids, str): - gpu_ids = [int(x) for x in gpu_ids.split(",")] - torch.cuda.set_device(gpu_ids[0]) - device = torch.device("cuda:" + str(gpu_ids[0])) - else: - torch.cuda.set_device(gpu_ids) - device = torch.device("cuda:" + str(gpu_ids)) - print("running on GPU {}".format(gpu_ids)) - else: - device = torch.device("cpu") - print("running on CPU") - - return device - - -###### file system ###### -def get_dir_size(start_path="."): - total_size = 0 - for dirpath, dirnames, filenames in os.walk(start_path): - for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) - return total_size - - -def mkdir_if_not(dir_path): - if not os.path.exists(dir_path): - os.makedirs(dir_path) - - -##### System related ###### -class Timer: - def __init__(self, msg): - self.msg = msg - self.start_time = None - - def __enter__(self): - self.start_time = time.time() - - def __exit__(self, exc_type, exc_value, exc_tb): - elapse = time.time() - self.start_time - print(self.msg % elapse) - - -###### interactive ###### -def get_size(start_path="."): - total_size = 0 - for dirpath, dirnames, filenames in os.walk(start_path): - for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) - return total_size - - -def clean_tensorboard(directory): - tensorboard_list = os.listdir(directory) - SIZE_THRESH = 100000 - for tensorboard in tensorboard_list: - tensorboard = os.path.join(directory, tensorboard) - if get_size(tensorboard) 
< SIZE_THRESH: - print("deleting the empty tensorboard: ", tensorboard) - # - if os.path.isdir(tensorboard): - shutil.rmtree(tensorboard) - else: - os.remove(tensorboard) - - -def prepare_tensorboard(config, experiment_name=datetime.now().strftime("%Y-%m-%d %H-%M-%S")): - tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs") - mkdir_if_not(tensorboard_directory) - clean_tensorboard(tensorboard_directory) - tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10) - - # try: - # shutil.copy('outputs/opt.txt', tensorboard_directory) - # except: - # print('cannot find file opt.txt') - return tb_writer - - -def tb_loss_logger(tb_writer, iter_index, loss_logger): - for tag, value in loss_logger.items(): - tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index) - - -def tb_image_logger(tb_writer, iter_index, images_info, config): - ### Save and write the output into the tensorboard - tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode) - mkdir_if_not(tb_logger_path) - for tag, image in images_info.items(): - if tag == "test_image_prediction" or tag == "image_prediction": - continue - image = tv.utils.make_grid(image.cpu()) - image = torch.clamp(image, 0, 1) - tb_writer.add_image(tag, img_tensor=image, global_step=iter_index) - tv.transforms.functional.to_pil_image(image).save( - os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag)) - ) - - -def tb_image_logger_test(epoch, iter, images_info, config): - - url = os.path.join(config.output_dir, config.name, config.train_mode, "val_" + str(epoch)) - if not os.path.exists(url): - os.makedirs(url) - scratch_img = images_info["test_scratch_image"].data.cpu() - if config.norm_input: - scratch_img = (scratch_img + 1.0) / 2.0 - scratch_img = torch.clamp(scratch_img, 0, 1) - gt_mask = images_info["test_mask_image"].data.cpu() - predict_mask = images_info["test_scratch_prediction"].data.cpu() - - predict_hard_mask = (predict_mask.data.cpu() >= 0.5).float() - - imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0) - img_grid = vutils.save_image( - imgs, os.path.join(url, str(iter) + ".jpg"), nrow=len(scratch_img), padding=0, normalize=True - ) - - -def imshow(input_image, title=None, to_numpy=False): - inp = input_image - if to_numpy or type(input_image) is torch.Tensor: - inp = input_image.numpy() - - fig = plt.figure() - if inp.ndim == 2: - fig = plt.imshow(inp, cmap="gray", clim=[0, 255]) - else: - fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8)) - plt.axis("off") - fig.axes.get_xaxis().set_visible(False) - fig.axes.get_yaxis().set_visible(False) - plt.title(title) - - -###### vgg preprocessing ###### -def vgg_preprocess(tensor): - # input is RGB tensor which ranges in [0,1] - # output is BGR tensor which ranges in [0,255] - tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1) - # tensor_bgr = tensor[:, [2, 1, 0], ...] 
- tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view( - 1, 3, 1, 1 - ) - tensor_rst = tensor_bgr_ml * 255 - return tensor_rst - - -def torch_vgg_preprocess(tensor): - # pytorch version normalization - # note that both input and output are RGB tensors; - # input and output ranges in [0,1] - # normalize the tensor with mean and variance - tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1) - tensor_mc_norm = tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1) - return tensor_mc_norm - - -def network_gradient(net, gradient_on=True): - if gradient_on: - for param in net.parameters(): - param.requires_grad = True - else: - for param in net.parameters(): - param.requires_grad = False - return net diff --git a/Global/models/NonLocal_feature_mapping_model.py b/Global/models/NonLocal_feature_mapping_model.py deleted file mode 100644 index 43deb62..0000000 --- a/Global/models/NonLocal_feature_mapping_model.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import torch.nn as nn -from . import networks - - -class Mapping_Model_with_mask(nn.Module): - def __init__( - self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None - ): - super(Mapping_Model_with_mask, self).__init__() - - norm_layer = networks.get_norm_layer(norm_type=norm) - activation = nn.ReLU(True) - model = [] - - tmp_nc = 64 - n_up = 4 - - for i in range(n_up): - ic = min(tmp_nc * (2**i), mc) - oc = min(tmp_nc * (2 ** (i + 1)), mc) - model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] - - self.before_NL = nn.Sequential(*model) - - if opt.NL_res: - self.NL = networks.NonLocalBlock2D_with_mask_Res( - mc, - mc, - opt.NL_fusion_method, - opt.correlation_renormalize, - opt.softmax_temperature, - opt.use_self, - opt.cosin_similarity, - ) - - print("using NL + Res...") - - model = [] - for i in range(n_blocks): - model += [ - networks.ResnetBlock( - mc, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - dilation=opt.mapping_net_dilation, - ) - ] - - for i in range(n_up - 1): - ic = min(64 * (2 ** (4 - i)), mc) - oc = min(64 * (2 ** (3 - i)), mc) - model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] - model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] - if opt.feat_dim > 0 and opt.feat_dim < 64: - model += [ - norm_layer(tmp_nc), - activation, - nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1), - ] - # model += [nn.Conv2d(64, 1, 1, 1, 0)] - self.after_NL = nn.Sequential(*model) - - def forward(self, input, mask): - x1 = self.before_NL(input) - del input - x2 = self.NL(x1, mask) - del x1, mask - x3 = self.after_NL(x2) - del x2 - - return x3 - - -class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention - def __init__( - self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None - ): - super(Mapping_Model_with_mask_2, self).__init__() - - norm_layer = networks.get_norm_layer(norm_type=norm) - activation = nn.ReLU(True) - model = [] - - tmp_nc = 64 - n_up = 4 - - for i in range(n_up): - ic = min(tmp_nc * (2**i), mc) - oc = min(tmp_nc * (2 ** (i + 1)), mc) - model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] - - for i in range(2): - model += [ - networks.ResnetBlock( - mc, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - dilation=opt.mapping_net_dilation, - ) - ] - - print("using multi-scale patch attention, conv combine + mask 
input...") - - self.before_NL = nn.Sequential(*model) - - if opt.mapping_exp == 1: - self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8) - - model = [] - for i in range(2): - model += [ - networks.ResnetBlock( - mc, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - dilation=opt.mapping_net_dilation, - ) - ] - - self.res_block_1 = nn.Sequential(*model) - - if opt.mapping_exp == 1: - self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4) - - model = [] - for i in range(2): - model += [ - networks.ResnetBlock( - mc, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - dilation=opt.mapping_net_dilation, - ) - ] - - self.res_block_2 = nn.Sequential(*model) - - if opt.mapping_exp == 1: - self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2) - # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2) - - model = [] - for i in range(2): - model += [ - networks.ResnetBlock( - mc, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - dilation=opt.mapping_net_dilation, - ) - ] - - for i in range(n_up - 1): - ic = min(64 * (2 ** (4 - i)), mc) - oc = min(64 * (2 ** (3 - i)), mc) - model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] - model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] - if opt.feat_dim > 0 and opt.feat_dim < 64: - model += [ - norm_layer(tmp_nc), - activation, - nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1), - ] - # model += [nn.Conv2d(64, 1, 1, 1, 0)] - self.after_NL = nn.Sequential(*model) - - def forward(self, input, mask): - x1 = self.before_NL(input) - x2 = self.NL_scale_1(x1, mask) - x3 = self.res_block_1(x2) - x4 = self.NL_scale_2(x3, mask) - x5 = self.res_block_2(x4) - x6 = self.NL_scale_3(x5, mask) - x7 = self.after_NL(x6) - return x7 - - def inference_forward(self, input, mask): - x1 = self.before_NL(input) - del input - x2 = self.NL_scale_1.inference_forward(x1, mask) - del x1 - x3 = self.res_block_1(x2) - del x2 - x4 = self.NL_scale_2.inference_forward(x3, mask) - del x3 - x5 = self.res_block_2(x4) - del x4 - x6 = self.NL_scale_3.inference_forward(x5, mask) - del x5 - x7 = self.after_NL(x6) - del x6 - return x7 diff --git a/Global/models/__init__.py b/Global/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/Global/models/base_model.py b/Global/models/base_model.py deleted file mode 100644 index 0622995..0000000 --- a/Global/models/base_model.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import os -import torch -import sys - - -class BaseModel(torch.nn.Module): - def name(self): - return "BaseModel" - - def initialize(self, opt): - self.opt = opt - self.gpu_ids = opt.gpu_ids - self.isTrain = opt.isTrain - self.Tensor = torch.cuda.FloatTensor if (torch.cuda.is_available() and self.gpu_ids) else torch.Tensor - self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) - - def set_input(self, input): - self.input = input - - def forward(self): - pass - - # used in test time, no backprop - def test(self): - pass - - def get_image_paths(self): - pass - - def optimize_parameters(self): - pass - - def get_current_visuals(self): - return self.input - - def get_current_errors(self): - return {} - - def save(self, label): - pass - - # helper saving function that can be used by subclasses - def save_network(self, network, network_label, epoch_label, gpu_ids): - save_filename = "%s_net_%s.pth" % (epoch_label, network_label) - save_path = os.path.join(self.save_dir, save_filename) - torch.save(network.cpu().state_dict(), save_path) - if len(gpu_ids) and torch.cuda.is_available(): - network.cuda() - - def save_optimizer(self, optimizer, optimizer_label, epoch_label): - save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label) - save_path = os.path.join(self.save_dir, save_filename) - torch.save(optimizer.state_dict(), save_path) - - def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""): - save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label) - if not save_dir: - save_dir = self.save_dir - save_path = os.path.join(save_dir, save_filename) - - if not os.path.isfile(save_path): - print("%s not exists yet!" % save_path) - else: - optimizer.load_state_dict(torch.load(save_path)) - - # helper loading function that can be used by subclasses - def load_network(self, network, network_label, epoch_label, save_dir=""): - save_filename = "%s_net_%s.pth" % (epoch_label, network_label) - if not save_dir: - save_dir = self.save_dir - - # print(save_dir) - # print(self.save_dir) - save_path = os.path.join(save_dir, save_filename) - if not os.path.isfile(save_path): - print("%s not exists yet!" 
% save_path) - # if network_label == 'G': - # raise('Generator must exist!') - else: - # network.load_state_dict(torch.load(save_path)) - try: - # print(save_path) - network.load_state_dict(torch.load(save_path)) - except: - pretrained_dict = torch.load(save_path) - model_dict = network.state_dict() - try: - pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} - network.load_state_dict(pretrained_dict) - # if self.opt.verbose: - print( - "Pretrained network %s has excessive layers; Only loading layers that are used" - % network_label - ) - except: - print( - "Pretrained network %s has fewer layers; The following are not initialized:" - % network_label - ) - for k, v in pretrained_dict.items(): - if v.size() == model_dict[k].size(): - model_dict[k] = v - - if sys.version_info >= (3, 0): - not_initialized = set() - else: - from sets import Set - - not_initialized = Set() - - for k, v in model_dict.items(): - if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): - not_initialized.add(k.split(".")[0]) - - print(sorted(not_initialized)) - network.load_state_dict(model_dict) - - def update_learning_rate(): - pass diff --git a/Global/models/mapping_model.py b/Global/models/mapping_model.py deleted file mode 100644 index bc69325..0000000 --- a/Global/models/mapping_model.py +++ /dev/null @@ -1,453 +0,0 @@ -# Copyright (c) Microsoft Corporation - -from .NonLocal_feature_mapping_model import * -from ..util.image_pool import ImagePool -from .base_model import BaseModel -from . import networks - -from torch.autograd import Variable -import torch.nn as nn -import torch - - -class Mapping_Model(nn.Module): - def __init__( - self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None - ): - super(Mapping_Model, self).__init__() - - norm_layer = networks.get_norm_layer(norm_type=norm) - activation = nn.ReLU(True) - model = [] - tmp_nc = 64 - n_up = 4 - - # print("Mapping: You are using the mapping model without global restoration.") - - for i in range(n_up): - ic = min(tmp_nc * (2**i), mc) - oc = min(tmp_nc * (2 ** (i + 1)), mc) - model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] - for i in range(n_blocks): - model += [ - networks.ResnetBlock( - mc, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - dilation=opt.mapping_net_dilation, - ) - ] - - for i in range(n_up - 1): - ic = min(64 * (2 ** (4 - i)), mc) - oc = min(64 * (2 ** (3 - i)), mc) - model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] - model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] - if opt.feat_dim > 0 and opt.feat_dim < 64: - model += [ - norm_layer(tmp_nc), - activation, - nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1), - ] - # model += [nn.Conv2d(64, 1, 1, 1, 0)] - self.model = nn.Sequential(*model) - - def forward(self, input): - return self.model(input) - - -class Pix2PixHDModel_Mapping(BaseModel): - def name(self): - return "Pix2PixHDModel_Mapping" - - def init_loss_filter( - self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2 - ): - flags = ( - True, - True, - use_gan_feat_loss, - use_vgg_loss, - True, - True, - use_smooth_l1, - stage_1_feat_l2, - ) - - def loss_filter( - g_feat_l2, - g_gan, - g_gan_feat, - g_vgg, - d_real, - d_fake, - smooth_l1, - stage_1_feat_l2, - ): - return [ - l - for (l, f) in zip( - ( - g_feat_l2, - g_gan, - g_gan_feat, - g_vgg, - d_real, - d_fake, - smooth_l1, - stage_1_feat_l2, - ), - flags, - ) - if f - ] - - return loss_filter - - def initialize(self, opt): - 
BaseModel.initialize(self, opt) - if opt.resize_or_crop != "none" or not opt.isTrain: - torch.backends.cudnn.benchmark = True - self.isTrain = opt.isTrain - input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc - - ##### define networks - # Generator network - netG_input_nc = input_nc - self.netG_A = networks.GlobalGenerator_DCDCv2( - netG_input_nc, - opt.output_nc, - opt.ngf, - opt.k_size, - opt.n_downsample_global, - networks.get_norm_layer(norm_type=opt.norm), - opt=opt, - ) - self.netG_B = networks.GlobalGenerator_DCDCv2( - netG_input_nc, - opt.output_nc, - opt.ngf, - opt.k_size, - opt.n_downsample_global, - networks.get_norm_layer(norm_type=opt.norm), - opt=opt, - ) - - if opt.non_local == "Setting_42" or opt.NL_use_mask: - if opt.mapping_exp == 1: - self.mapping_net = Mapping_Model_with_mask_2( - min(opt.ngf * 2**opt.n_downsample_global, opt.mc), - opt.map_mc, - n_blocks=opt.mapping_n_block, - opt=opt, - ) - else: - self.mapping_net = Mapping_Model_with_mask( - min(opt.ngf * 2**opt.n_downsample_global, opt.mc), - opt.map_mc, - n_blocks=opt.mapping_n_block, - opt=opt, - ) - else: - self.mapping_net = Mapping_Model( - min(opt.ngf * 2**opt.n_downsample_global, opt.mc), - opt.map_mc, - n_blocks=opt.mapping_n_block, - opt=opt, - ) - - self.mapping_net.apply(networks.weights_init) - - if opt.load_pretrain != "": - self.load_network( - self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain - ) - - if not opt.no_load_VAE: - - self.load_network( - self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA - ) - self.load_network( - self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB - ) - for param in self.netG_A.parameters(): - param.requires_grad = False - for param in self.netG_B.parameters(): - param.requires_grad = False - self.netG_A.eval() - self.netG_B.eval() - - if torch.cuda.is_available() and opt.gpu_ids: - self.netG_A.cuda(opt.gpu_ids[0]) - self.netG_B.cuda(opt.gpu_ids[0]) - self.mapping_net.cuda(opt.gpu_ids[0]) - - if not self.isTrain: - self.load_network(self.mapping_net, "mapping_net", opt.which_epoch) - - # Discriminator network - if self.isTrain: - use_sigmoid = opt.no_lsgan - netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc - if not opt.no_instance: - netD_input_nc += 1 - - self.netD = networks.define_D( - netD_input_nc, - opt.ndf, - opt.n_layers_D, - opt, - opt.norm, - use_sigmoid, - opt.num_D, - not opt.no_ganFeat_loss, - gpu_ids=self.gpu_ids, - ) - - # set loss functions and optimizers - if self.isTrain: - if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: - raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") - self.fake_pool = ImagePool(opt.pool_size) - self.old_lr = opt.lr - - # define loss functions - self.loss_filter = self.init_loss_filter( - not opt.no_ganFeat_loss, - not opt.no_vgg_loss, - opt.Smooth_L1, - opt.use_two_stage_mapping, - ) - - self.criterionGAN = networks.GANLoss( - use_lsgan=not opt.no_lsgan, tensor=self.Tensor - ) - - self.criterionFeat = torch.nn.L1Loss() - self.criterionFeat_feat = ( - torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss() - ) - - if self.opt.image_L1: - self.criterionImage = torch.nn.L1Loss() - else: - self.criterionImage = torch.nn.SmoothL1Loss() - - print(self.criterionFeat_feat) - if not opt.no_vgg_loss: - self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) - - # Names so we can breakout loss - self.loss_names = self.loss_filter( - "G_Feat_L2", - "G_GAN", - "G_GAN_Feat", - "G_VGG", - "D_real", - "D_fake", - "Smooth_L1", - "G_Feat_L2_Stage_1", 
- ) - - # initialize optimizers - # optimizer G - - if opt.no_TTUR: - beta1, beta2 = opt.beta1, 0.999 - G_lr, D_lr = opt.lr, opt.lr - else: - beta1, beta2 = 0, 0.9 - G_lr, D_lr = opt.lr / 2, opt.lr * 2 - - if not opt.no_load_VAE: - params = list(self.mapping_net.parameters()) - self.optimizer_mapping = torch.optim.Adam( - params, lr=G_lr, betas=(beta1, beta2) - ) - - # optimizer D - params = list(self.netD.parameters()) - self.optimizer_D = torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2)) - - print("---------- Optimizers initialized -------------") - - def encode_input( - self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False - ): - if self.opt.label_nc == 0: - input_label = label_map.data.cuda() - else: - # create one-hot vector for label map - size = label_map.size() - oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) - input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() - input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) - if self.opt.data_type == 16: - input_label = input_label.half() - - # get edges from instance map - if not self.opt.no_instance: - inst_map = inst_map.data.cuda() - edge_map = self.get_edges(inst_map) - input_label = torch.cat((input_label, edge_map), dim=1) - input_label = Variable(input_label, volatile=infer) - - # real images for training - if real_image is not None: - real_image = Variable(real_image.data.cuda()) - - return input_label, inst_map, real_image, feat_map - - def discriminate(self, input_label, test_image, use_pool=False): - input_concat = torch.cat((input_label, test_image.detach()), dim=1) - if use_pool: - fake_query = self.fake_pool.query(input_concat) - return self.netD.forward(fake_query) - else: - return self.netD.forward(input_concat) - - def forward( - self, - label, - inst, - image, - feat, - pair=True, - infer=False, - last_label=None, - last_image=None, - ): - # Encode Inputs - input_label, inst_map, real_image, feat_map = self.encode_input( - label, inst, image, feat - ) - - # Fake Generation - input_concat = input_label - - label_feat = self.netG_A.forward(input_concat, flow="enc") - # print('label:') - # print(label_feat.min(), label_feat.max(), label_feat.mean()) - # label_feat = label_feat / 16.0 - - if self.opt.NL_use_mask: - label_feat_map = self.mapping_net(label_feat.detach(), inst) - else: - label_feat_map = self.mapping_net(label_feat.detach()) - - fake_image = self.netG_B.forward(label_feat_map, flow="dec") - image_feat = self.netG_B.forward(real_image, flow="enc") - - loss_feat_l2_stage_1 = 0 - loss_feat_l2 = ( - self.criterionFeat_feat(label_feat_map, image_feat.data) * self.opt.l2_feat - ) - - if self.opt.feat_gan: - # Fake Detection and Loss - pred_fake_pool = self.discriminate( - label_feat.detach(), label_feat_map, use_pool=True - ) - loss_D_fake = self.criterionGAN(pred_fake_pool, False) - - # Real Detection and Loss - pred_real = self.discriminate(label_feat.detach(), image_feat) - loss_D_real = self.criterionGAN(pred_real, True) - - # GAN loss (Fake Passability Loss) - pred_fake = self.netD.forward( - torch.cat((label_feat.detach(), label_feat_map), dim=1) - ) - loss_G_GAN = self.criterionGAN(pred_fake, True) - else: - # Fake Detection and Loss - pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) - loss_D_fake = self.criterionGAN(pred_fake_pool, False) - - # Real Detection and Loss - if pair: - pred_real = self.discriminate(input_label, real_image) - else: - pred_real = self.discriminate(last_label, last_image) - 
loss_D_real = self.criterionGAN(pred_real, True) - - # GAN loss (Fake Passability Loss) - pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) - loss_G_GAN = self.criterionGAN(pred_fake, True) - - # GAN feature matching loss - loss_G_GAN_Feat = 0 - if not self.opt.no_ganFeat_loss and pair: - feat_weights = 4.0 / (self.opt.n_layers_D + 1) - D_weights = 1.0 / self.opt.num_D - for i in range(self.opt.num_D): - for j in range(len(pred_fake[i]) - 1): - tmp = ( - self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) - * self.opt.lambda_feat - ) - loss_G_GAN_Feat += D_weights * feat_weights * tmp - else: - loss_G_GAN_Feat = torch.zeros(1).to(label.device) - - # VGG feature matching loss - loss_G_VGG = 0 - if not self.opt.no_vgg_loss: - loss_G_VGG = ( - self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat - if pair - else torch.zeros(1).to(label.device) - ) - - smooth_l1_loss = 0 - if self.opt.Smooth_L1: - smooth_l1_loss = ( - self.criterionImage(fake_image, real_image) * self.opt.L1_weight - ) - - return [ - self.loss_filter( - loss_feat_l2, - loss_G_GAN, - loss_G_GAN_Feat, - loss_G_VGG, - loss_D_real, - loss_D_fake, - smooth_l1_loss, - loss_feat_l2_stage_1, - ), - None if not infer else fake_image, - ] - - def inference(self, label, inst): - - use_gpu = torch.cuda.is_available() and len(self.opt.gpu_ids) > 0 - if use_gpu: - input_concat = label.data.cuda() - inst_data = inst.cuda() - else: - input_concat = label.data - inst_data = inst - - label_feat = self.netG_A.forward(input_concat, flow="enc") - - if self.opt.NL_use_mask: - if self.opt.inference_optimize: - label_feat_map = self.mapping_net.inference_forward( - label_feat.detach(), inst_data - ) - else: - label_feat_map = self.mapping_net(label_feat.detach(), inst_data) - else: - label_feat_map = self.mapping_net(label_feat.detach()) - - fake_image = self.netG_B.forward(label_feat_map, flow="dec") - return fake_image - - -class InferenceModel(Pix2PixHDModel_Mapping): - def forward(self, label, inst): - return self.inference(label, inst) diff --git a/Global/models/models.py b/Global/models/models.py deleted file mode 100644 index 69d4755..0000000 --- a/Global/models/models.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Microsoft Corporation - - -def create_model(opt): - assert opt.model == "pix2pixHD" - - from .pix2pixHD_model import Pix2PixHDModel, InferenceModel - - if opt.isTrain: - model = Pix2PixHDModel() - else: - model = InferenceModel() - - model.initialize(opt) - if opt.verbose: - print("model [%s] was created" % (model.name())) - - assert not opt.isTrain - - return model - - -def create_da_model(opt): - assert opt.model == "pix2pixHD" - - from .pix2pixHD_model_DA import Pix2PixHDModel, InferenceModel - - if opt.isTrain: - model = Pix2PixHDModel() - else: - model = InferenceModel() - - model.initialize(opt) - if opt.verbose: - print("model [%s] was created" % (model.name())) - - assert not opt.isTrain - - return model diff --git a/Global/models/networks.py b/Global/models/networks.py deleted file mode 100644 index 6c4b086..0000000 --- a/Global/models/networks.py +++ /dev/null @@ -1,875 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
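Stripped of the training-only branches, the inference path of Pix2PixHDModel_Mapping above reduces to a three-step pipeline. The sketch below restates it under the assumption that `model` is an initialized Pix2PixHDModel_Mapping and `degraded` / `scratch_mask` are preloaded input tensors (names are hypothetical).

    import torch

    with torch.no_grad():
        latent = model.netG_A.forward(degraded, flow="enc")     # encode the degraded input
        if model.opt.NL_use_mask:
            latent = model.mapping_net(latent, scratch_mask)    # mask-guided feature mapping
        else:
            latent = model.mapping_net(latent)
        restored = model.netG_B.forward(latent, flow="dec")     # decode the restored features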
- -import torch -import torch.nn as nn -import functools -from torch.autograd import Variable -import numpy as np -from torch.nn.utils import spectral_norm - -# from util.util import SwitchNorm2d -import torch.nn.functional as F - -############################################################################### -# Functions -############################################################################### -def weights_init(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(0.0, 0.02) - elif classname.find("BatchNorm2d") != -1: - m.weight.data.normal_(1.0, 0.02) - m.bias.data.fill_(0) - - -def get_norm_layer(norm_type="instance"): - if norm_type == "batch": - norm_layer = functools.partial(nn.BatchNorm2d, affine=True) - elif norm_type == "instance": - norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) - elif norm_type == "spectral": - norm_layer = spectral_norm() - elif norm_type == "SwitchNorm": - norm_layer = SwitchNorm2d - else: - raise NotImplementedError("normalization layer [%s] is not found" % norm_type) - return norm_layer - - -def print_network(net): - if isinstance(net, list): - net = net[0] - num_params = 0 - for param in net.parameters(): - num_params += param.numel() - print(net) - print("Total number of parameters: %d" % num_params) - - -def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1, - n_blocks_local=3, norm='instance', gpu_ids=[], opt=None): - - norm_layer = get_norm_layer(norm_type=norm) - if netG == 'global': - # if opt.self_gen: - if opt.use_v2: - netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt) - else: - netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt) - else: - raise('generator not implemented!') - print(netG) - if len(gpu_ids) > 0: - assert(torch.cuda.is_available()) - netG.cuda(gpu_ids[0]) - netG.apply(weights_init) - return netG - - -def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): - norm_layer = get_norm_layer(norm_type=norm) - netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) - print(netD) - if len(gpu_ids) > 0: - assert(torch.cuda.is_available()) - netD.cuda(gpu_ids[0]) - netD.apply(weights_init) - return netD - - - -class GlobalGenerator_DCDCv2(nn.Module): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - k_size=3, - n_downsampling=8, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - opt=None, - ): - super(GlobalGenerator_DCDCv2, self).__init__() - activation = nn.ReLU(True) - - model = [ - nn.ReflectionPad2d(3), - nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0), - norm_layer(ngf), - activation, - ] - ### downsample - for i in range(opt.start_r): - mult = 2 ** i - model += [ - nn.Conv2d( - min(ngf * mult, opt.mc), - min(ngf * mult * 2, opt.mc), - kernel_size=k_size, - stride=2, - padding=1, - ), - norm_layer(min(ngf * mult * 2, opt.mc)), - activation, - ] - for i in range(opt.start_r, n_downsampling - 1): - mult = 2 ** i - model += [ - nn.Conv2d( - min(ngf * mult, opt.mc), - min(ngf * mult * 2, opt.mc), - kernel_size=k_size, - stride=2, - padding=1, - ), - norm_layer(min(ngf * mult * 2, opt.mc)), - activation, - ] - model += [ - ResnetBlock( - min(ngf * mult * 2, opt.mc), - padding_type=padding_type, - activation=activation, - 
norm_layer=norm_layer, - opt=opt, - ) - ] - model += [ - ResnetBlock( - min(ngf * mult * 2, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - mult = 2 ** (n_downsampling - 1) - - if opt.spatio_size == 32: - model += [ - nn.Conv2d( - min(ngf * mult, opt.mc), - min(ngf * mult * 2, opt.mc), - kernel_size=k_size, - stride=2, - padding=1, - ), - norm_layer(min(ngf * mult * 2, opt.mc)), - activation, - ] - if opt.spatio_size == 64: - model += [ - ResnetBlock( - min(ngf * mult * 2, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - model += [ - ResnetBlock( - min(ngf * mult * 2, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)] - if opt.feat_dim > 0: - model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)] - self.encoder = nn.Sequential(*model) - - # decode - model = [] - if opt.feat_dim > 0: - model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)] - # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)] - o_pad = 0 if k_size == 4 else 1 - mult = 2 ** n_downsampling - model += [ - ResnetBlock( - min(ngf * mult, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - - if opt.spatio_size == 32: - model += [ - nn.ConvTranspose2d( - min(ngf * mult, opt.mc), - min(int(ngf * mult / 2), opt.mc), - kernel_size=k_size, - stride=2, - padding=1, - output_padding=o_pad, - ), - norm_layer(min(int(ngf * mult / 2), opt.mc)), - activation, - ] - if opt.spatio_size == 64: - model += [ - ResnetBlock( - min(ngf * mult, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - - for i in range(1, n_downsampling - opt.start_r): - mult = 2 ** (n_downsampling - i) - model += [ - ResnetBlock( - min(ngf * mult, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - model += [ - ResnetBlock( - min(ngf * mult, opt.mc), - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - opt=opt, - ) - ] - model += [ - nn.ConvTranspose2d( - min(ngf * mult, opt.mc), - min(int(ngf * mult / 2), opt.mc), - kernel_size=k_size, - stride=2, - padding=1, - output_padding=o_pad, - ), - norm_layer(min(int(ngf * mult / 2), opt.mc)), - activation, - ] - for i in range(n_downsampling - opt.start_r, n_downsampling): - mult = 2 ** (n_downsampling - i) - model += [ - nn.ConvTranspose2d( - min(ngf * mult, opt.mc), - min(int(ngf * mult / 2), opt.mc), - kernel_size=k_size, - stride=2, - padding=1, - output_padding=o_pad, - ), - norm_layer(min(int(ngf * mult / 2), opt.mc)), - activation, - ] - if opt.use_segmentation_model: - model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)] - else: - model += [ - nn.ReflectionPad2d(3), - nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0), - nn.Tanh(), - ] - self.decoder = nn.Sequential(*model) - - def forward(self, input, flow="enc_dec"): - if flow == "enc": - return self.encoder(input) - elif flow == "dec": - return self.decoder(input) - elif flow == "enc_dec": - x = self.encoder(input) - x = self.decoder(x) - return x - - -# Define a resnet block -class ResnetBlock(nn.Module): - def __init__( - self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, 
dilation=1 - ): - super(ResnetBlock, self).__init__() - self.opt = opt - self.dilation = dilation - self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout) - - def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): - conv_block = [] - p = 0 - if padding_type == "reflect": - conv_block += [nn.ReflectionPad2d(self.dilation)] - elif padding_type == "replicate": - conv_block += [nn.ReplicationPad2d(self.dilation)] - elif padding_type == "zero": - p = self.dilation - else: - raise NotImplementedError("padding [%s] is not implemented" % padding_type) - - conv_block += [ - nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation), - norm_layer(dim), - activation, - ] - if use_dropout: - conv_block += [nn.Dropout(0.5)] - - p = 0 - if padding_type == "reflect": - conv_block += [nn.ReflectionPad2d(1)] - elif padding_type == "replicate": - conv_block += [nn.ReplicationPad2d(1)] - elif padding_type == "zero": - p = 1 - else: - raise NotImplementedError("padding [%s] is not implemented" % padding_type) - conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)] - - return nn.Sequential(*conv_block) - - def forward(self, x): - out = x + self.conv_block(x) - return out - - -class Encoder(nn.Module): - def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): - super(Encoder, self).__init__() - self.output_nc = output_nc - - model = [ - nn.ReflectionPad2d(3), - nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), - norm_layer(ngf), - nn.ReLU(True), - ] - ### downsample - for i in range(n_downsampling): - mult = 2 ** i - model += [ - nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), - norm_layer(ngf * mult * 2), - nn.ReLU(True), - ] - - ### upsample - for i in range(n_downsampling): - mult = 2 ** (n_downsampling - i) - model += [ - nn.ConvTranspose2d( - ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1 - ), - norm_layer(int(ngf * mult / 2)), - nn.ReLU(True), - ] - - model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] - self.model = nn.Sequential(*model) - - def forward(self, input, inst): - outputs = self.model(input) - - # instance-wise average pooling - outputs_mean = outputs.clone() - inst_list = np.unique(inst.cpu().numpy().astype(int)) - for i in inst_list: - for b in range(input.size()[0]): - indices = (inst[b : b + 1] == int(i)).nonzero() # n x 4 - for j in range(self.output_nc): - output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]] - mean_feat = torch.mean(output_ins).expand_as(output_ins) - outputs_mean[ - indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3] - ] = mean_feat - return outputs_mean - - -def SN(module, mode=True): - if mode: - return torch.nn.utils.spectral_norm(module) - - return module - - -class NonLocalBlock2D_with_mask_Res(nn.Module): - def __init__( - self, - in_channels, - inter_channels, - mode="add", - re_norm=False, - temperature=1.0, - use_self=False, - cosin=False, - ): - super(NonLocalBlock2D_with_mask_Res, self).__init__() - - self.cosin = cosin - self.renorm = re_norm - self.in_channels = in_channels - self.inter_channels = inter_channels - - self.g = nn.Conv2d( - in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 - ) - - self.W = nn.Conv2d( - in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, 
stride=1, padding=0 - ) - # for pytorch 0.3.1 - # nn.init.constant(self.W.weight, 0) - # nn.init.constant(self.W.bias, 0) - # for pytorch 0.4.0 - nn.init.constant_(self.W.weight, 0) - nn.init.constant_(self.W.bias, 0) - self.theta = nn.Conv2d( - in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 - ) - - self.phi = nn.Conv2d( - in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 - ) - - self.mode = mode - self.temperature = temperature - self.use_self = use_self - - norm_layer = get_norm_layer(norm_type="instance") - activation = nn.ReLU(True) - - model = [] - for i in range(3): - model += [ - ResnetBlock( - inter_channels, - padding_type="reflect", - activation=activation, - norm_layer=norm_layer, - opt=None, - ) - ] - self.res_block = nn.Sequential(*model) - - def forward(self, x, mask): ## The shape of mask is Batch*1*H*W - batch_size = x.size(0) - - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - - g_x = g_x.permute(0, 2, 1) - - theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) - - theta_x = theta_x.permute(0, 2, 1) - - phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) - - if self.cosin: - theta_x = F.normalize(theta_x, dim=2) - phi_x = F.normalize(phi_x, dim=1) - - f = torch.matmul(theta_x, phi_x) - - f /= self.temperature - - f_div_C = F.softmax(f, dim=2) - - tmp = 1 - mask - mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") - mask[mask > 0] = 1.0 - mask = 1 - mask - - tmp = F.interpolate(tmp, (x.size(2), x.size(3))) - mask *= tmp - - mask_expand = mask.view(batch_size, 1, -1) - mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1) - - # mask = 1 - mask - # mask=F.interpolate(mask,(x.size(2),x.size(3))) - # mask_expand=mask.view(batch_size,1,-1) - # mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1) - - if self.use_self: - mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0 - - # print(mask_expand.shape) - # print(f_div_C.shape) - - f_div_C = mask_expand * f_div_C - if self.renorm: - f_div_C = F.normalize(f_div_C, p=1, dim=2) - - ########################### - - y = torch.matmul(f_div_C, g_x) - - y = y.permute(0, 2, 1).contiguous() - - y = y.view(batch_size, self.inter_channels, *x.size()[2:]) - W_y = self.W(y) - - W_y = self.res_block(W_y) - - if self.mode == "combine": - full_mask = mask.repeat(1, self.inter_channels, 1, 1) - z = full_mask * x + (1 - full_mask) * W_y - return z - - -class MultiscaleDiscriminator(nn.Module): - def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, - use_sigmoid=False, num_D=3, getIntermFeat=False): - super(MultiscaleDiscriminator, self).__init__() - self.num_D = num_D - self.n_layers = n_layers - self.getIntermFeat = getIntermFeat - - for i in range(num_D): - netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) - if getIntermFeat: - for j in range(n_layers+2): - setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) - else: - setattr(self, 'layer'+str(i), netD.model) - - self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) - - def singleD_forward(self, model, input): - if self.getIntermFeat: - result = [input] - for i in range(len(model)): - result.append(model[i](result[-1])) - return result[1:] - else: - return [model(input)] - - def forward(self, input): - num_D = self.num_D - result = [] - input_downsampled = input - for i in 
range(num_D): - if self.getIntermFeat: - model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] - else: - model = getattr(self, 'layer'+str(num_D-1-i)) - result.append(self.singleD_forward(model, input_downsampled)) - if i != (num_D-1): - input_downsampled = self.downsample(input_downsampled) - return result - -# Defines the PatchGAN discriminator with the specified arguments. -class NLayerDiscriminator(nn.Module): - def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): - super(NLayerDiscriminator, self).__init__() - self.getIntermFeat = getIntermFeat - self.n_layers = n_layers - - kw = 4 - padw = int(np.ceil((kw-1.0)/2)) - sequence = [[SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),opt.use_SN), nn.LeakyReLU(0.2, True)]] - - nf = ndf - for n in range(1, n_layers): - nf_prev = nf - nf = min(nf * 2, 512) - sequence += [[ - SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),opt.use_SN), - norm_layer(nf), nn.LeakyReLU(0.2, True) - ]] - - nf_prev = nf - nf = min(nf * 2, 512) - sequence += [[ - SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),opt.use_SN), - norm_layer(nf), - nn.LeakyReLU(0.2, True) - ]] - - sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw),opt.use_SN)]] - - if use_sigmoid: - sequence += [[nn.Sigmoid()]] - - if getIntermFeat: - for n in range(len(sequence)): - setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) - else: - sequence_stream = [] - for n in range(len(sequence)): - sequence_stream += sequence[n] - self.model = nn.Sequential(*sequence_stream) - - def forward(self, input): - if self.getIntermFeat: - res = [input] - for n in range(self.n_layers+2): - model = getattr(self, 'model'+str(n)) - res.append(model(res[-1])) - return res[1:] - else: - return self.model(input) - - - -class Patch_Attention_4(nn.Module): ## While combine the feature map, use conv and mask - def __init__(self, in_channels, inter_channels, patch_size): - super(Patch_Attention_4, self).__init__() - - self.patch_size=patch_size - - - # self.g = nn.Conv2d( - # in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 - # ) - - # self.W = nn.Conv2d( - # in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 - # ) - # # for pytorch 0.3.1 - # # nn.init.constant(self.W.weight, 0) - # # nn.init.constant(self.W.bias, 0) - # # for pytorch 0.4.0 - # nn.init.constant_(self.W.weight, 0) - # nn.init.constant_(self.W.bias, 0) - # self.theta = nn.Conv2d( - # in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 - # ) - - # self.phi = nn.Conv2d( - # in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 - # ) - - self.F_Combine=nn.Conv2d(in_channels=1025,out_channels=512,kernel_size=3,stride=1,padding=1,bias=True) - norm_layer = get_norm_layer(norm_type="instance") - activation = nn.ReLU(True) - - model = [] - for i in range(1): - model += [ - ResnetBlock( - inter_channels, - padding_type="reflect", - activation=activation, - norm_layer=norm_layer, - opt=None, - ) - ] - self.res_block = nn.Sequential(*model) - - def Hard_Compose(self, input, dim, index): - # batch index select - # input: [B,C,HW] - # dim: scalar > 0 - # index: [B, HW] - views = [input.size(0)] + [1 if i!=dim else -1 for i in range(1, len(input.size()))] - expanse = list(input.size()) - 
expanse[0] = -1 - expanse[dim] = -1 - index = index.view(views).expand(expanse) - return torch.gather(input, dim, index) - - def forward(self, z, mask): ## The shape of mask is Batch*1*H*W - - x=self.res_block(z) - - b,c,h,w=x.shape - - ## mask resize + dilation - # tmp = 1 - mask - mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") - mask[mask > 0] = 1.0 - - # mask = 1 - mask - # tmp = F.interpolate(tmp, (x.size(2), x.size(3))) - # mask *= tmp - # mask=1-mask - ## 1: mask position 0: non-mask - - mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) - non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float() - all_patch_num=h*w/self.patch_size/self.patch_size - non_mask_region=non_mask_region.repeat(1,int(all_patch_num),1) - - x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) - y_unfold=x_unfold.permute(0,2,1) - x_unfold_normalized=F.normalize(x_unfold,dim=1) - y_unfold_normalized=F.normalize(y_unfold,dim=2) - correlation_matrix=torch.bmm(y_unfold_normalized,x_unfold_normalized) - correlation_matrix=correlation_matrix.masked_fill(non_mask_region==1.,-1e9) - correlation_matrix=F.softmax(correlation_matrix,dim=2) - - # print(correlation_matrix) - - R, max_arg=torch.max(correlation_matrix,dim=2) - - composed_unfold=self.Hard_Compose(x_unfold, 2, max_arg) - composed_fold=F.fold(composed_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size) - - concat_1=torch.cat((z,composed_fold,mask),dim=1) - concat_1=self.F_Combine(concat_1) - - return concat_1 - - def inference_forward(self,z,mask): ## Reduce the extra memory cost - - - x=self.res_block(z) - - b,c,h,w=x.shape - - ## mask resize + dilation - # tmp = 1 - mask - mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") - mask[mask > 0] = 1.0 - # mask = 1 - mask - # tmp = F.interpolate(tmp, (x.size(2), x.size(3))) - # mask *= tmp - # mask=1-mask - ## 1: mask position 0: non-mask - - mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) - non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()[0,0,:] # 1*1*all_patch_num - - all_patch_num=h*w/self.patch_size/self.patch_size - - mask_index=torch.nonzero(non_mask_region,as_tuple=True)[0] - - - if len(mask_index)==0: ## No mask patch is selected, no attention is needed - - composed_fold=x - - else: - - unmask_index=torch.nonzero(non_mask_region!=1,as_tuple=True)[0] - - x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) - - Query_Patch=torch.index_select(x_unfold,2,mask_index) - Key_Patch=torch.index_select(x_unfold,2,unmask_index) - - Query_Patch=Query_Patch.permute(0,2,1) - Query_Patch_normalized=F.normalize(Query_Patch,dim=2) - Key_Patch_normalized=F.normalize(Key_Patch,dim=1) - - correlation_matrix=torch.bmm(Query_Patch_normalized,Key_Patch_normalized) - correlation_matrix=F.softmax(correlation_matrix,dim=2) - - - R, max_arg=torch.max(correlation_matrix,dim=2) - - composed_unfold=self.Hard_Compose(Key_Patch, 2, max_arg) - x_unfold[:,:,mask_index]=composed_unfold - composed_fold=F.fold(x_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size) - - concat_1=torch.cat((z,composed_fold,mask),dim=1) - concat_1=self.F_Combine(concat_1) - - - return concat_1 - 
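# The Patch_Attention_4 block above composes features by copying the best-matching
# non-damaged patch into each location (unfold -> cosine correlation -> argmax -> gather -> fold).
# Below is a minimal, self-contained sketch of that hard-composition step on dummy tensors;
# the 7x7 patch size, the tensor shapes, and the 0.6 threshold are illustrative assumptions only.
import torch
import torch.nn.functional as F

def hard_compose_sketch(x, mask, patch_size=7):
    # x: B x C x H x W feature map; mask: B x 1 x H x W, 1 marks the damaged region
    b, c, h, w = x.shape
    x_unfold = F.unfold(x, kernel_size=patch_size, stride=patch_size)        # B x (C*p*p) x N
    mask_unfold = F.unfold(mask, kernel_size=patch_size, stride=patch_size)  # B x (p*p) x N
    damaged = (mask_unfold.mean(dim=1, keepdim=True) > 0.6).float()          # B x 1 x N, patch is mostly damaged

    q = F.normalize(x_unfold.permute(0, 2, 1), dim=2)                        # queries: B x N x (C*p*p)
    k = F.normalize(x_unfold, dim=1)                                         # keys:    B x (C*p*p) x N
    corr = torch.bmm(q, k)                                                   # B x N x N cosine similarity
    corr = corr.masked_fill(damaged.repeat(1, corr.size(1), 1) == 1.0, -1e9) # never copy from damaged patches
    best = corr.argmax(dim=2)                                                # B x N, index of best source patch

    idx = best.unsqueeze(1).expand(-1, x_unfold.size(1), -1)                 # broadcast the index over channels
    composed = torch.gather(x_unfold, 2, idx)                                # batch-wise "Hard_Compose"
    return F.fold(composed, output_size=(h, w), kernel_size=patch_size, stride=patch_size)

# usage on dummy data (H and W are assumed divisible by the patch size)
feat = torch.randn(1, 8, 28, 28)
hole = torch.zeros(1, 1, 28, 28)
hole[:, :, 7:14, 7:14] = 1.0
composed_feat = hard_compose_sketch(feat, hole)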
-############################################################################## -# Losses -############################################################################## -class GANLoss(nn.Module): - def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, - tensor=torch.FloatTensor): - super(GANLoss, self).__init__() - self.real_label = target_real_label - self.fake_label = target_fake_label - self.real_label_var = None - self.fake_label_var = None - self.Tensor = tensor - if use_lsgan: - self.loss = nn.MSELoss() - else: - self.loss = nn.BCELoss() - - def get_target_tensor(self, input, target_is_real): - target_tensor = None - if target_is_real: - create_label = ((self.real_label_var is None) or - (self.real_label_var.numel() != input.numel())) - if create_label: - real_tensor = self.Tensor(input.size()).fill_(self.real_label) - self.real_label_var = Variable(real_tensor, requires_grad=False) - target_tensor = self.real_label_var - else: - create_label = ((self.fake_label_var is None) or - (self.fake_label_var.numel() != input.numel())) - if create_label: - fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) - self.fake_label_var = Variable(fake_tensor, requires_grad=False) - target_tensor = self.fake_label_var - return target_tensor - - def __call__(self, input, target_is_real): - if isinstance(input[0], list): - loss = 0 - for input_i in input: - pred = input_i[-1] - target_tensor = self.get_target_tensor(pred, target_is_real) - loss += self.loss(pred, target_tensor) - return loss - else: - target_tensor = self.get_target_tensor(input[-1], target_is_real) - return self.loss(input[-1], target_tensor) - - - - -####################################### VGG Loss - -from torchvision import models -class VGG19_torch(torch.nn.Module): - def __init__(self, requires_grad=False): - super(VGG19_torch, self).__init__() - vgg_pretrained_features = models.vgg19(pretrained=True).features - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - for x in range(2): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(2, 7): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(7, 12): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(12, 21): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in range(21, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - h_relu1 = self.slice1(X) - h_relu2 = self.slice2(h_relu1) - h_relu3 = self.slice3(h_relu2) - h_relu4 = self.slice4(h_relu3) - h_relu5 = self.slice5(h_relu4) - out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] - return out - -class VGGLoss_torch(nn.Module): - def __init__(self, gpu_ids): - super(VGGLoss_torch, self).__init__() - self.vgg = VGG19_torch().cuda() - self.criterion = nn.L1Loss() - self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] - - def forward(self, x, y): - x_vgg, y_vgg = self.vgg(x), self.vgg(y) - loss = 0 - for i in range(len(x_vgg)): - loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) - return loss \ No newline at end of file diff --git a/Global/models/pix2pixHD_model.py b/Global/models/pix2pixHD_model.py deleted file mode 100644 index edf829f..0000000 --- a/Global/models/pix2pixHD_model.py +++ /dev/null @@ -1,333 +0,0 @@ -# 
Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import numpy as np -import torch -import os -from torch.autograd import Variable -from util.image_pool import ImagePool -from .base_model import BaseModel -from . import networks - -class Pix2PixHDModel(BaseModel): - def name(self): - return 'Pix2PixHDModel' - - def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_L1): - flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True,use_smooth_L1) - def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,smooth_l1): - return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg, g_kl, d_real,d_fake,smooth_l1),flags) if f] - return loss_filter - - def initialize(self, opt): - BaseModel.initialize(self, opt) - if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM - torch.backends.cudnn.benchmark = True - self.isTrain = opt.isTrain - self.use_features = opt.instance_feat or opt.label_feat ## Clearly it is false - self.gen_features = self.use_features and not self.opt.load_features ## it is also false - input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## Just is the origin input channel # - - ##### define networks - # Generator network - netG_input_nc = input_nc - if not opt.no_instance: - netG_input_nc += 1 - if self.use_features: - netG_input_nc += opt.feat_num - self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size, - opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, - opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt) - - # Discriminator network - if self.isTrain: - use_sigmoid = opt.no_lsgan - netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc - if not opt.no_instance: - netD_input_nc += 1 - self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, - opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) - - if self.opt.verbose: - print('---------- Networks initialized -------------') - - # load networks - if not self.isTrain or opt.continue_train or opt.load_pretrain: - pretrained_path = '' if not self.isTrain else opt.load_pretrain - self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) - - print("---------- G Networks reloaded -------------") - if self.isTrain: - self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) - print("---------- D Networks reloaded -------------") - - - if self.gen_features: - self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) - - # set loss functions and optimizers - if self.isTrain: - if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0! 
- raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") - self.fake_pool = ImagePool(opt.pool_size) - self.old_lr = opt.lr - - # define loss functions - self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1) - - self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) - self.criterionFeat = torch.nn.L1Loss() - - # self.criterionImage = torch.nn.SmoothL1Loss() - if not opt.no_vgg_loss: - self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) - - - self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1') - - # initialize optimizers - # optimizer G - params = list(self.netG.parameters()) - if self.gen_features: - params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - - # optimizer D - params = list(self.netD.parameters()) - self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - - print("---------- Optimizers initialized -------------") - - if opt.continue_train: - self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch) - self.load_optimizer(self.optimizer_G, "G", opt.which_epoch) - for param_groups in self.optimizer_D.param_groups: - self.old_lr=param_groups['lr'] - - print("---------- Optimizers reloaded -------------") - print("---------- Current LR is %.8f -------------"%(self.old_lr)) - - ## We also want to re-load the parameters of optimizer. - - def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): - if self.opt.label_nc == 0: - input_label = label_map.data.cuda() - else: - # create one-hot vector for label map - size = label_map.size() - oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) - input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() - input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) - if self.opt.data_type == 16: - input_label = input_label.half() - - # get edges from instance map - if not self.opt.no_instance: - inst_map = inst_map.data.cuda() - edge_map = self.get_edges(inst_map) - input_label = torch.cat((input_label, edge_map), dim=1) - input_label = Variable(input_label, volatile=infer) - - # real images for training - if real_image is not None: - real_image = Variable(real_image.data.cuda()) - - # instance map for feature encoding - if self.use_features: - # get precomputed feature maps - if self.opt.load_features: - feat_map = Variable(feat_map.data.cuda()) - if self.opt.label_feat: - inst_map = label_map.cuda() - - return input_label, inst_map, real_image, feat_map - - def discriminate(self, input_label, test_image, use_pool=False): - if input_label is None: - input_concat = test_image.detach() - else: - input_concat = torch.cat((input_label, test_image.detach()), dim=1) - if use_pool: - fake_query = self.fake_pool.query(input_concat) - return self.netD.forward(fake_query) - else: - return self.netD.forward(input_concat) - - def forward(self, label, inst, image, feat, infer=False): - # Encode Inputs - input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) - - # Fake Generation - if self.use_features: - if not self.opt.load_features: - feat_map = self.netE.forward(real_image, inst_map) - input_concat = torch.cat((input_label, feat_map), dim=1) - else: - input_concat = input_label - hiddens = self.netG.forward(input_concat, 'enc') - noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 
- # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. - # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py) - fake_image = self.netG.forward(hiddens + noise, 'dec') - - if self.opt.no_cgan: - # Fake Detection and Loss - pred_fake_pool = self.discriminate(None, fake_image, use_pool=True) - loss_D_fake = self.criterionGAN(pred_fake_pool, False) - - # Real Detection and Loss - pred_real = self.discriminate(None, real_image) - loss_D_real = self.criterionGAN(pred_real, True) - - # GAN loss (Fake Passability Loss) - pred_fake = self.netD.forward(fake_image) - loss_G_GAN = self.criterionGAN(pred_fake, True) - else: - # Fake Detection and Loss - pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) - loss_D_fake = self.criterionGAN(pred_fake_pool, False) - - # Real Detection and Loss - pred_real = self.discriminate(input_label, real_image) - loss_D_real = self.criterionGAN(pred_real, True) - - # GAN loss (Fake Passability Loss) - pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) - loss_G_GAN = self.criterionGAN(pred_fake, True) - - - loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl - - # GAN feature matching loss - loss_G_GAN_Feat = 0 - if not self.opt.no_ganFeat_loss: - feat_weights = 4.0 / (self.opt.n_layers_D + 1) - D_weights = 1.0 / self.opt.num_D - for i in range(self.opt.num_D): - for j in range(len(pred_fake[i])-1): - loss_G_GAN_Feat += D_weights * feat_weights * \ - self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat - - # VGG feature matching loss - loss_G_VGG = 0 - if not self.opt.no_vgg_loss: - loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat - - - smooth_l1_loss=0 - - return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,smooth_l1_loss ), None if not infer else fake_image ] - - def inference(self, label, inst, image=None, feat=None): - # Encode Inputs - image = Variable(image) if image is not None else None - input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True) - - # Fake Generation - if self.use_features: - if self.opt.use_encoded_image: - # encode the real image to get feature map - feat_map = self.netE.forward(real_image, inst_map) - else: - # sample clusters from precomputed features - feat_map = self.sample_features(inst_map) - input_concat = torch.cat((input_label, feat_map), dim=1) - else: - input_concat = input_label - - if torch.__version__.startswith('0.4'): - with torch.no_grad(): - fake_image = self.netG.forward(input_concat) - else: - fake_image = self.netG.forward(input_concat) - return fake_image - - def sample_features(self, inst): - # read precomputed feature clusters - cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) - features_clustered = np.load(cluster_path, encoding='latin1').item() - - # randomly sample from the feature clusters - inst_np = inst.cpu().numpy().astype(int) - feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]) - for i in np.unique(inst_np): - label = i if i < 1000 else i//1000 - if label in features_clustered: - feat = features_clustered[label] - cluster_idx = np.random.randint(0, feat.shape[0]) - - idx = (inst == int(i)).nonzero() - for k in range(self.opt.feat_num): - feat_map[idx[:,0], idx[:,1] + k, idx[:,2], 
idx[:,3]] = feat[cluster_idx, k] - if self.opt.data_type==16: - feat_map = feat_map.half() - return feat_map - - def encode_features(self, image, inst): - image = Variable(image.cuda(), volatile=True) - feat_num = self.opt.feat_num - h, w = inst.size()[2], inst.size()[3] - block_num = 32 - feat_map = self.netE.forward(image, inst.cuda()) - inst_np = inst.cpu().numpy().astype(int) - feature = {} - for i in range(self.opt.label_nc): - feature[i] = np.zeros((0, feat_num+1)) - for i in np.unique(inst_np): - label = i if i < 1000 else i//1000 - idx = (inst == int(i)).nonzero() - num = idx.size()[0] - idx = idx[num//2,:] - val = np.zeros((1, feat_num+1)) - for k in range(feat_num): - val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] - val[0, feat_num] = float(num) / (h * w // block_num) - feature[label] = np.append(feature[label], val, axis=0) - return feature - - def get_edges(self, t): - edge = torch.cuda.ByteTensor(t.size()).zero_() - edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) - edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) - edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) - edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) - if self.opt.data_type==16: - return edge.half() - else: - return edge.float() - - def save(self, which_epoch): - self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) - self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) - - self.save_optimizer(self.optimizer_G,"G",which_epoch) - self.save_optimizer(self.optimizer_D,"D",which_epoch) - - if self.gen_features: - self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) - - def update_fixed_params(self): - - params = list(self.netG.parameters()) - if self.gen_features: - params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) - if self.opt.verbose: - print('------------ Now also finetuning global generator -----------') - - def update_learning_rate(self): - lrd = self.opt.lr / self.opt.niter_decay - lr = self.old_lr - lrd - for param_group in self.optimizer_D.param_groups: - param_group['lr'] = lr - for param_group in self.optimizer_G.param_groups: - param_group['lr'] = lr - if self.opt.verbose: - print('update learning rate: %f -> %f' % (self.old_lr, lr)) - self.old_lr = lr - - -class InferenceModel(Pix2PixHDModel): - def forward(self, inp): - label, inst = inp - return self.inference(label, inst) diff --git a/Global/models/pix2pixHD_model_DA.py b/Global/models/pix2pixHD_model_DA.py deleted file mode 100644 index 617589d..0000000 --- a/Global/models/pix2pixHD_model_DA.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import numpy as np -import torch -import os -from torch.autograd import Variable -from util.image_pool import ImagePool -from .base_model import BaseModel -from . 
import networks - - -class Pix2PixHDModel(BaseModel): - def name(self): - return 'Pix2PixHDModel' - - def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss): - flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True, True) - - def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake): - return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake), flags) if f] - - return loss_filter - - def initialize(self, opt): - BaseModel.initialize(self, opt) - if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM - torch.backends.cudnn.benchmark = True - self.isTrain = opt.isTrain - self.use_features = opt.instance_feat or opt.label_feat ## Clearly it is false - self.gen_features = self.use_features and not self.opt.load_features ## it is also false - input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## Just is the origin input channel # - - ##### define networks - # Generator network - netG_input_nc = input_nc - if not opt.no_instance: - netG_input_nc += 1 - if self.use_features: - netG_input_nc += opt.feat_num - self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size, - opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, - opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt) - - # Discriminator network - if self.isTrain: - use_sigmoid = opt.no_lsgan - netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc - if not opt.no_instance: - netD_input_nc += 1 - self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt,opt.norm, use_sigmoid, - opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) - - self.feat_D=networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, - 1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) - - if self.opt.verbose: - print('---------- Networks initialized -------------') - - # load networks - if not self.isTrain or opt.continue_train or opt.load_pretrain: - pretrained_path = '' if not self.isTrain else opt.load_pretrain - self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) - - print("---------- G Networks reloaded -------------") - if self.isTrain: - self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) - self.load_network(self.feat_D, 'feat_D', opt.which_epoch, pretrained_path) - print("---------- D Networks reloaded -------------") - - - # set loss functions and optimizers - if self.isTrain: - if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0! 
- raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") - self.fake_pool = ImagePool(opt.pool_size) - self.old_lr = opt.lr - - # define loss functions - self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) - - self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) - self.criterionFeat = torch.nn.L1Loss() - if not opt.no_vgg_loss: - self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) - - # Names so we can breakout loss - self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'G_featD', 'featD_real','featD_fake') - - # initialize optimizers - # optimizer G - params = list(self.netG.parameters()) - if self.gen_features: - params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - - # optimizer D - params = list(self.netD.parameters()) - self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - - params = list(self.feat_D.parameters()) - self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) - - print("---------- Optimizers initialized -------------") - - if opt.continue_train: - self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch) - self.load_optimizer(self.optimizer_G, "G", opt.which_epoch) - self.load_optimizer(self.optimizer_featD,'featD',opt.which_epoch) - for param_groups in self.optimizer_D.param_groups: - self.old_lr = param_groups['lr'] - - print("---------- Optimizers reloaded -------------") - print("---------- Current LR is %.8f -------------" % (self.old_lr)) - - ## We also want to re-load the parameters of optimizer. - - def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): - if self.opt.label_nc == 0: - input_label = label_map.data.cuda() - else: - # create one-hot vector for label map - size = label_map.size() - oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) - input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() - input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) - if self.opt.data_type == 16: - input_label = input_label.half() - - # get edges from instance map - if not self.opt.no_instance: - inst_map = inst_map.data.cuda() - edge_map = self.get_edges(inst_map) - input_label = torch.cat((input_label, edge_map), dim=1) - input_label = Variable(input_label, volatile=infer) - - # real images for training - if real_image is not None: - real_image = Variable(real_image.data.cuda()) - - # instance map for feature encoding - if self.use_features: - # get precomputed feature maps - if self.opt.load_features: - feat_map = Variable(feat_map.data.cuda()) - if self.opt.label_feat: - inst_map = label_map.cuda() - - return input_label, inst_map, real_image, feat_map - - def discriminate(self, input_label, test_image, use_pool=False): - if input_label is None: - input_concat = test_image.detach() - else: - input_concat = torch.cat((input_label, test_image.detach()), dim=1) - if use_pool: - fake_query = self.fake_pool.query(input_concat) - return self.netD.forward(fake_query) - else: - return self.netD.forward(input_concat) - - def feat_discriminate(self,input): - - return self.feat_D.forward(input.detach()) - - - def forward(self, label, inst, image, feat, infer=False): - # Encode Inputs - input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) - - # Fake Generation - if self.use_features: - if not 
self.opt.load_features: - feat_map = self.netE.forward(real_image, inst_map) - input_concat = torch.cat((input_label, feat_map), dim=1) - else: - input_concat = input_label - hiddens = self.netG.forward(input_concat, 'enc') - noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) - # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. - # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py) - fake_image = self.netG.forward(hiddens + noise, 'dec') - - #################### - ##### GAN for the intermediate feature - real_old_feat =[] - syn_feat = [] - for index,x in enumerate(inst): - if x==1: - real_old_feat.append(hiddens[index].unsqueeze(0)) - else: - syn_feat.append(hiddens[index].unsqueeze(0)) - L=min(len(real_old_feat),len(syn_feat)) - real_old_feat=real_old_feat[:L] - syn_feat=syn_feat[:L] - real_old_feat=torch.cat(real_old_feat,0) - syn_feat=torch.cat(syn_feat,0) - - pred_fake_feat=self.feat_discriminate(real_old_feat) - loss_featD_fake = self.criterionGAN(pred_fake_feat, False) - pred_real_feat=self.feat_discriminate(syn_feat) - loss_featD_real = self.criterionGAN(pred_real_feat, True) - - pred_fake_feat_G=self.feat_D.forward(real_old_feat) - loss_G_featD=self.criterionGAN(pred_fake_feat_G,True) - - - ##################################### - if self.opt.no_cgan: - # Fake Detection and Loss - pred_fake_pool = self.discriminate(None, fake_image, use_pool=True) - loss_D_fake = self.criterionGAN(pred_fake_pool, False) - - # Real Detection and Loss - pred_real = self.discriminate(None, real_image) - loss_D_real = self.criterionGAN(pred_real, True) - - # GAN loss (Fake Passability Loss) - pred_fake = self.netD.forward(fake_image) - loss_G_GAN = self.criterionGAN(pred_fake, True) - else: - # Fake Detection and Loss - pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) - loss_D_fake = self.criterionGAN(pred_fake_pool, False) - - # Real Detection and Loss - pred_real = self.discriminate(input_label, real_image) - loss_D_real = self.criterionGAN(pred_real, True) - - # GAN loss (Fake Passability Loss) - pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) - loss_G_GAN = self.criterionGAN(pred_fake, True) - - loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl - - # GAN feature matching loss - loss_G_GAN_Feat = 0 - if not self.opt.no_ganFeat_loss: - feat_weights = 4.0 / (self.opt.n_layers_D + 1) - D_weights = 1.0 / self.opt.num_D - for i in range(self.opt.num_D): - for j in range(len(pred_fake[i]) - 1): - loss_G_GAN_Feat += D_weights * feat_weights * \ - self.criterionFeat(pred_fake[i][j], - pred_real[i][j].detach()) * self.opt.lambda_feat - - # VGG feature matching loss - loss_G_VGG = 0 - if not self.opt.no_vgg_loss: - loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat - - # Only return the fake_B image if necessary to save BW - return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,loss_G_featD, loss_featD_real, loss_featD_fake), - None if not infer else fake_image] - - def inference(self, label, inst, image=None, feat=None): - # Encode Inputs - image = Variable(image) if image is not None else None - input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True) - - # Fake Generation - if self.use_features: - if self.opt.use_encoded_image: - # encode the real image to get 
feature map - feat_map = self.netE.forward(real_image, inst_map) - else: - # sample clusters from precomputed features - feat_map = self.sample_features(inst_map) - input_concat = torch.cat((input_label, feat_map), dim=1) - else: - input_concat = input_label - - if torch.__version__.startswith('0.4'): - with torch.no_grad(): - fake_image = self.netG.forward(input_concat) - else: - fake_image = self.netG.forward(input_concat) - return fake_image - - def sample_features(self, inst): - # read precomputed feature clusters - cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) - features_clustered = np.load(cluster_path, encoding='latin1').item() - - # randomly sample from the feature clusters - inst_np = inst.cpu().numpy().astype(int) - feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]) - for i in np.unique(inst_np): - label = i if i < 1000 else i // 1000 - if label in features_clustered: - feat = features_clustered[label] - cluster_idx = np.random.randint(0, feat.shape[0]) - - idx = (inst == int(i)).nonzero() - for k in range(self.opt.feat_num): - feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k] - if self.opt.data_type == 16: - feat_map = feat_map.half() - return feat_map - - def encode_features(self, image, inst): - image = Variable(image.cuda(), volatile=True) - feat_num = self.opt.feat_num - h, w = inst.size()[2], inst.size()[3] - block_num = 32 - feat_map = self.netE.forward(image, inst.cuda()) - inst_np = inst.cpu().numpy().astype(int) - feature = {} - for i in range(self.opt.label_nc): - feature[i] = np.zeros((0, feat_num + 1)) - for i in np.unique(inst_np): - label = i if i < 1000 else i // 1000 - idx = (inst == int(i)).nonzero() - num = idx.size()[0] - idx = idx[num // 2, :] - val = np.zeros((1, feat_num + 1)) - for k in range(feat_num): - val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] - val[0, feat_num] = float(num) / (h * w // block_num) - feature[label] = np.append(feature[label], val, axis=0) - return feature - - def get_edges(self, t): - edge = torch.cuda.ByteTensor(t.size()).zero_() - edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) - edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) - edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) - edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) - if self.opt.data_type == 16: - return edge.half() - else: - return edge.float() - - def save(self, which_epoch): - self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) - self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) - self.save_network(self.feat_D,'featD',which_epoch,self.gpu_ids) - - self.save_optimizer(self.optimizer_G, "G", which_epoch) - self.save_optimizer(self.optimizer_D, "D", which_epoch) - self.save_optimizer(self.optimizer_featD,'featD',which_epoch) - - if self.gen_features: - self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) - - def update_fixed_params(self): - - params = list(self.netG.parameters()) - if self.gen_features: - params += list(self.netE.parameters()) - self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) - if self.opt.verbose: - print('------------ Now also finetuning global generator -----------') - - def update_learning_rate(self): - lrd = self.opt.lr / self.opt.niter_decay - lr = self.old_lr - lrd - for param_group in self.optimizer_D.param_groups: - 
param_group['lr'] = lr - for param_group in self.optimizer_G.param_groups: - param_group['lr'] = lr - for param_group in self.optimizer_featD.param_groups: - param_group['lr'] = lr - if self.opt.verbose: - print('update learning rate: %f -> %f' % (self.old_lr, lr)) - self.old_lr = lr - - -class InferenceModel(Pix2PixHDModel): - def forward(self, inp): - label, inst = inp - return self.inference(label, inst) diff --git a/Global/options/__init__.py b/Global/options/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/Global/options/base_options.py b/Global/options/base_options.py deleted file mode 100644 index 5917e47..0000000 --- a/Global/options/base_options.py +++ /dev/null @@ -1,469 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import argparse -import torch - - -class BaseOptions: - def __init__(self): - self.parser = argparse.ArgumentParser() - self.initialized = False - - def initialize(self): - # experiment specifics - self.parser.add_argument( - "--name", - type=str, - default="label2city", - help="name of the experiment. It decides where to store samples and models", - ) - self.parser.add_argument( - "--gpu_ids", - type=str, - default="0", - help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU", - ) - self.parser.add_argument( - "--checkpoints_dir", - type=str, - default="./checkpoints", - help="models are saved here", - ) ## note: to add this param when using philly - # self.parser.add_argument('--project_dir', type=str, default='./', help='the project is saved here') ################### This is necessary for philly - self.parser.add_argument( - "--outputs_dir", type=str, default="./outputs", help="models are saved here" - ) ## note: to add this param when using philly Please end with '/' - self.parser.add_argument( - "--model", type=str, default="pix2pixHD", help="which model to use" - ) - self.parser.add_argument( - "--norm", - type=str, - default="instance", - help="instance normalization or batch normalization", - ) - self.parser.add_argument( - "--use_dropout", action="store_true", help="use dropout for the generator" - ) - self.parser.add_argument( - "--data_type", - default=32, - type=int, - choices=[8, 16, 32], - help="Supported data type i.e. 
8, 16, 32 bit", - ) - self.parser.add_argument( - "--verbose", action="store_true", default=False, help="toggles verbose" - ) - - # input/output sizes - self.parser.add_argument( - "--batchSize", type=int, default=1, help="input batch size" - ) - self.parser.add_argument( - "--loadSize", type=int, default=1024, help="scale images to this size" - ) - self.parser.add_argument( - "--fineSize", type=int, default=512, help="then crop to this size" - ) - self.parser.add_argument( - "--label_nc", type=int, default=35, help="# of input label channels" - ) - self.parser.add_argument( - "--input_nc", type=int, default=3, help="# of input image channels" - ) - self.parser.add_argument( - "--output_nc", type=int, default=3, help="# of output image channels" - ) - - # for setting inputs - self.parser.add_argument( - "--dataroot", type=str, default="./datasets/cityscapes/" - ) - self.parser.add_argument( - "--resize_or_crop", - type=str, - default="scale_width", - help="scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]", - ) - self.parser.add_argument( - "--serial_batches", - action="store_true", - help="if true, takes images in order to make batches, otherwise takes them randomly", - ) - self.parser.add_argument( - "--no_flip", - action="store_true", - help="if specified, do not flip the images for data argumentation", - ) - self.parser.add_argument( - "--nThreads", default=2, type=int, help="# threads for loading data" - ) - self.parser.add_argument( - "--max_dataset_size", - type=int, - default=float("inf"), - help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.", - ) - - # for displays - self.parser.add_argument( - "--display_winsize", type=int, default=512, help="display window size" - ) - self.parser.add_argument( - "--tf_log", - action="store_true", - help="if specified, use tensorboard logging. 
Requires tensorflow installed", - ) - - # for generator - self.parser.add_argument( - "--netG", type=str, default="global", help="selects model to use for netG" - ) - self.parser.add_argument( - "--ngf", type=int, default=64, help="# of gen filters in first conv layer" - ) - self.parser.add_argument( - "--k_size", type=int, default=3, help="# kernel size conv layer" - ) - self.parser.add_argument("--use_v2", action="store_true", help="use DCDCv2") - self.parser.add_argument("--mc", type=int, default=1024, help="# max channel") - self.parser.add_argument( - "--start_r", type=int, default=3, help="start layer to use resblock" - ) - self.parser.add_argument( - "--n_downsample_global", - type=int, - default=4, - help="number of downsampling layers in netG", - ) - self.parser.add_argument( - "--n_blocks_global", - type=int, - default=9, - help="number of residual blocks in the global generator network", - ) - self.parser.add_argument( - "--n_blocks_local", - type=int, - default=3, - help="number of residual blocks in the local enhancer network", - ) - self.parser.add_argument( - "--n_local_enhancers", - type=int, - default=1, - help="number of local enhancers to use", - ) - self.parser.add_argument( - "--niter_fix_global", - type=int, - default=0, - help="number of epochs that we only train the outmost local enhancer", - ) - - self.parser.add_argument( - "--load_pretrain", - type=str, - default="", - help="load the pretrained model from the specified location", - ) - - # for instance-wise features - self.parser.add_argument( - "--no_instance", - action="store_true", - help="if specified, do *not* add instance map as input", - ) - self.parser.add_argument( - "--instance_feat", - action="store_true", - help="if specified, add encoded instance features as input", - ) - self.parser.add_argument( - "--label_feat", - action="store_true", - help="if specified, add encoded label features as input", - ) - self.parser.add_argument( - "--feat_num", type=int, default=3, help="vector length for encoded features" - ) - self.parser.add_argument( - "--load_features", - action="store_true", - help="if specified, load precomputed feature maps", - ) - self.parser.add_argument( - "--n_downsample_E", - type=int, - default=4, - help="# of downsampling layers in encoder", - ) - self.parser.add_argument( - "--nef", - type=int, - default=16, - help="# of encoder filters in the first conv layer", - ) - self.parser.add_argument( - "--n_clusters", type=int, default=10, help="number of clusters for features" - ) - - # diy - self.parser.add_argument( - "--self_gen", action="store_true", help="self generate" - ) - self.parser.add_argument( - "--mapping_n_block", - type=int, - default=3, - help="number of resblock in mapping", - ) - self.parser.add_argument( - "--map_mc", type=int, default=64, help="max channel of mapping" - ) - self.parser.add_argument("--kl", type=float, default=0, help="KL Loss") - self.parser.add_argument( - "--load_pretrainA", - type=str, - default="", - help="load the pretrained model from the specified location", - ) - self.parser.add_argument( - "--load_pretrainB", - type=str, - default="", - help="load the pretrained model from the specified location", - ) - self.parser.add_argument("--feat_gan", action="store_true") - self.parser.add_argument("--no_cgan", action="store_true") - self.parser.add_argument("--map_unet", action="store_true") - self.parser.add_argument("--map_densenet", action="store_true") - self.parser.add_argument("--fcn", action="store_true") - self.parser.add_argument( - 
"--is_image", action="store_true", help="train image recon only pair data" - ) - self.parser.add_argument("--label_unpair", action="store_true") - self.parser.add_argument("--mapping_unpair", action="store_true") - self.parser.add_argument("--unpair_w", type=float, default=1.0) - self.parser.add_argument("--pair_num", type=int, default=-1) - self.parser.add_argument("--Gan_w", type=float, default=1) - self.parser.add_argument("--feat_dim", type=int, default=-1) - self.parser.add_argument("--abalation_vae_len", type=int, default=-1) - - ######################### useless, just to cooperate with docker - self.parser.add_argument("--gpu", type=str) - self.parser.add_argument("--dataDir", type=str) - self.parser.add_argument("--modelDir", type=str) - self.parser.add_argument("--logDir", type=str) - self.parser.add_argument("--data_dir", type=str) - - self.parser.add_argument("--use_skip_model", action="store_true") - self.parser.add_argument("--use_segmentation_model", action="store_true") - - self.parser.add_argument("--spatio_size", type=int, default=64) - self.parser.add_argument("--test_random_crop", action="store_true") - ########################## - - self.parser.add_argument("--contain_scratch_L", action="store_true") - self.parser.add_argument( - "--mask_dilation", type=int, default=0 - ) ## Don't change the input, only dilation the mask - - self.parser.add_argument( - "--irregular_mask", - type=str, - default="", - help="This is the root of the mask", - ) - self.parser.add_argument( - "--mapping_net_dilation", - type=int, - default=1, - help="This parameter is the dilation size of the translation net", - ) - - self.parser.add_argument( - "--VOC", - type=str, - default="VOC_RGB_JPEGImages.bigfile", - help="The root of VOC dataset", - ) - - self.parser.add_argument( - "--non_local", type=str, default="", help="which non_local setting" - ) - self.parser.add_argument( - "--NL_fusion_method", - type=str, - default="add", - help="how to fuse the origin feature and nl feature", - ) - self.parser.add_argument( - "--NL_use_mask", - action="store_true", - help="If use mask while using Non-local mapping model", - ) - self.parser.add_argument( - "--correlation_renormalize", - action="store_true", - help="Since after mask out the correlation matrix(which is softmaxed), the sum is not 1 any more, enable this param to re-weight", - ) - - self.parser.add_argument( - "--Smooth_L1", action="store_true", help="Use L1 Loss in image level" - ) - - self.parser.add_argument( - "--face_restore_setting", - type=int, - default=1, - help="This is for the aligned face restoration", - ) - self.parser.add_argument("--face_clean_url", type=str, default="") - self.parser.add_argument("--syn_input_url", type=str, default="") - self.parser.add_argument("--syn_gt_url", type=str, default="") - - self.parser.add_argument( - "--test_on_synthetic", - action="store_true", - help="If you want to test on the synthetic data, enable this parameter", - ) - - self.parser.add_argument( - "--use_SN", action="store_true", help="Add SN to every parametric layer" - ) - - self.parser.add_argument( - "--use_two_stage_mapping", - action="store_true", - help="choose the model which uses two stage", - ) - - self.parser.add_argument("--L1_weight", type=float, default=10.0) - self.parser.add_argument("--softmax_temperature", type=float, default=1.0) - self.parser.add_argument( - "--patch_similarity", - action="store_true", - help="Enable this denotes using 3*3 patch to calculate similarity", - ) - self.parser.add_argument( - "--use_self", - 
action="store_true", - help="Enable this denotes that while constructing the new feature maps, using original feature (diagonal == 1)", - ) - - self.parser.add_argument("--use_own_dataset", action="store_true") - - self.parser.add_argument( - "--test_hole_two_folders", - action="store_true", - help="Enable this parameter means test the restoration with inpainting given twp folders which are mask and old respectively", - ) - - self.parser.add_argument( - "--no_hole", - action="store_true", - help="While test the full_model on non_scratch data, do not add random mask into the real old photos", - ) ## Only for testing - self.parser.add_argument( - "--random_hole", - action="store_true", - help="While training the full model, 50% probability add hole", - ) - - self.parser.add_argument( - "--NL_res", action="store_true", help="NL+Resdual Block" - ) - - self.parser.add_argument( - "--image_L1", action="store_true", help="Image level loss: L1" - ) - self.parser.add_argument( - "--hole_image_no_mask", - action="store_true", - help="while testing, give hole image but not give the mask", - ) - - self.parser.add_argument( - "--down_sample_degradation", - action="store_true", - help="down_sample the image only, corresponds to [down_sample_face]", - ) - - self.parser.add_argument( - "--norm_G", - type=str, - default="spectralinstance", - help="The norm type of Generator", - ) - self.parser.add_argument( - "--init_G", - type=str, - default="xavier", - help="normal|xavier|xavier_uniform|kaiming|orthogonal|none", - ) - - self.parser.add_argument("--use_new_G", action="store_true") - self.parser.add_argument("--use_new_D", action="store_true") - - self.parser.add_argument( - "--only_voc", - action="store_true", - help="test the trianed celebA face model using VOC face", - ) - - self.parser.add_argument( - "--cosin_similarity", - action="store_true", - help="For non-local, using cosin to calculate the similarity", - ) - - self.parser.add_argument( - "--downsample_mode", - type=str, - default="nearest", - help="For partial non-local, choose how to downsample the mask", - ) - - self.parser.add_argument( - "--mapping_exp", - type=int, - default=0, - help="Default 0: original PNL|1: Multi-Scale Patch Attention", - ) - self.parser.add_argument( - "--inference_optimize", action="store_true", help="optimize the memory cost" - ) - - self.initialized = True - - def parse(self, custom_args: list): - - assert len(custom_args) > 0, "Manually Pass Arguments!" - - if not self.initialized: - self.initialize() - - self.opt = self.parser.parse_args(custom_args) - self.opt.isTrain = self.isTrain # train or test - - str_ids = self.opt.gpu_ids.split(",") - self.opt.gpu_ids = [] - for str_id in str_ids: - int_id = int(str_id) - if int_id >= 0: - self.opt.gpu_ids.append(int_id) - - # set gpu ids - if torch.cuda.is_available() and len(self.opt.gpu_ids) > 0: - if torch.cuda.device_count() > self.opt.gpu_ids[0]: - try: - torch.cuda.set_device(self.opt.gpu_ids[0]) - except: - print("Failed to set GPU device. Using CPU...") - - else: - print("Invalid GPU ID. Using CPU...") - - return self.opt diff --git a/Global/options/test_options.py b/Global/options/test_options.py deleted file mode 100644 index 67e2e3a..0000000 --- a/Global/options/test_options.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from .base_options import BaseOptions - - -class TestOptions(BaseOptions): - def initialize(self): - BaseOptions.initialize(self) - self.parser.add_argument("--ntest", type=int, default=float("inf"), help="# of test examples.") - self.parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.") - self.parser.add_argument( - "--aspect_ratio", type=float, default=1.0, help="aspect ratio of result images" - ) - self.parser.add_argument("--phase", type=str, default="test", help="train, val, test, etc") - self.parser.add_argument( - "--which_epoch", - type=str, - default="latest", - help="which epoch to load? set to latest to use latest cached model", - ) - self.parser.add_argument("--how_many", type=int, default=50, help="how many test images to run") - self.parser.add_argument( - "--cluster_path", - type=str, - default="features_clustered_010.npy", - help="the path for clustered results of encoded features", - ) - self.parser.add_argument( - "--use_encoded_image", - action="store_true", - help="if specified, encode the real image to get the feature map", - ) - self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") - self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") - self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") - self.parser.add_argument( - "--start_epoch", - type=int, - default=-1, - help="write the start_epoch of iter.txt into this parameter", - ) - - self.parser.add_argument("--test_dataset", type=str, default="Real_RGB_old.bigfile") - self.parser.add_argument( - "--no_degradation", - action="store_true", - help="when train the mapping, enable this parameter --> no degradation will be added into clean image", - ) - self.parser.add_argument( - "--no_load_VAE", - action="store_true", - help="when train the mapping, enable this parameter --> random initialize the encoder an decoder", - ) - self.parser.add_argument( - "--use_v2_degradation", - action="store_true", - help="enable this parameter --> 4 kinds of degradations will be used to synthesize corruption", - ) - self.parser.add_argument("--use_vae_which_epoch", type=str, default="latest") - self.isTrain = False - - self.parser.add_argument("--generate_pair", action="store_true") - - self.parser.add_argument("--multi_scale_test", type=float, default=0.5) - self.parser.add_argument("--multi_scale_threshold", type=float, default=0.5) - self.parser.add_argument( - "--mask_need_scale", - action="store_true", - help="enable this param meas that the pixel range of mask is 0-255", - ) - self.parser.add_argument("--scale_num", type=int, default=1) - - self.parser.add_argument( - "--save_feature_url", type=str, default="", help="While extracting the features, where to put" - ) - - self.parser.add_argument( - "--test_input", type=str, default="", help="A directory or a root of bigfile" - ) - self.parser.add_argument("--test_mask", type=str, default="", help="A directory or a root of bigfile") - self.parser.add_argument("--test_gt", type=str, default="", help="A directory or a root of bigfile") - - self.parser.add_argument( - "--scale_input", action="store_true", help="While testing, choose to scale the input firstly" - ) - - self.parser.add_argument( - "--save_feature_name", type=str, default="features.json", help="The name of saved features" - ) - self.parser.add_argument( - "--test_rgb_old_wo_scratch", action="store_true", help="Same setting with origin test" - ) - - self.parser.add_argument("--test_mode", 
type=str, default="Crop", help="Scale|Full|Crop") - self.parser.add_argument("--Quality_restore", action="store_true", help="For RGB images") - self.parser.add_argument( - "--Scratch_and_Quality_restore", action="store_true", help="For scratched images" - ) - self.parser.add_argument("--HR", action='store_true',help='Large input size with scratches') diff --git a/Global/test.py b/Global/test.py deleted file mode 100644 index 18f1e84..0000000 --- a/Global/test.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) Microsoft Corporation - -import torchvision.transforms as transforms -from PIL import Image -import numpy as np -import torch -import cv2 -import os - -from .models.mapping_model import Pix2PixHDModel_Mapping -from .options.test_options import TestOptions - -tensor2image = transforms.ToPILImage() - - -def data_transforms(img, method=Image.BILINEAR, scale=False): - - ow, oh = img.size - pw, ph = ow, oh - if scale == True: - if ow < oh: - ow = 256 - oh = ph / pw * 256 - else: - oh = 256 - ow = pw / ph * 256 - - h = int(round(oh / 4) * 4) - w = int(round(ow / 4) * 4) - - if (h == ph) and (w == pw): - return img - - return img.resize((w, h), method) - - -def data_transforms_rgb_old(img): - w, h = img.size - A = img - if w < 256 or h < 256: - A = transforms.Scale(256, Image.BILINEAR)(img) - return transforms.CenterCrop(256)(A) - - -def irregular_hole_synthesize(img, mask): - - img_np = np.array(img).astype("uint8") - mask_np = np.array(mask).astype("uint8") - mask_np = mask_np / 255 - img_new = img_np * (1 - mask_np) + mask_np * 255 - - hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB") - - return hole_img - - -def parameter_set(opt, ckpt_dir): - ## Default parameters - opt.serial_batches = True # no shuffle - opt.no_flip = True # no flip - opt.label_nc = 0 - opt.n_downsample_global = 3 - opt.mc = 64 - opt.k_size = 4 - opt.start_r = 1 - opt.mapping_n_block = 6 - opt.map_mc = 512 - opt.no_instance = True - opt.checkpoints_dir = ckpt_dir - - if opt.Quality_restore: - opt.name = "mapping_quality" - opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") - opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality") - if opt.Scratch_and_Quality_restore: - opt.NL_res = True - opt.use_SN = True - opt.correlation_renormalize = True - opt.NL_use_mask = True - opt.NL_fusion_method = "combine" - opt.non_local = "Setting_42" - opt.name = "mapping_scratch" - opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") - opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch") - if opt.HR: - opt.mapping_exp = 1 - opt.inference_optimize = True - opt.mask_dilation = 3 - opt.name = "mapping_Patch_Attention" - - -def global_test( - ckpt_dir: str, custom_args: list, input_image: Image, mask: Image = None -) -> Image: - - opt = TestOptions().parse(custom_args) - - parameter_set(opt, ckpt_dir) - - model = Pix2PixHDModel_Mapping() - - model.initialize(opt) - model.eval() - - img_transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) - mask_transform = transforms.ToTensor() - - print("processing...") - - if opt.NL_use_mask: - if opt.mask_dilation != 0: - kernel = np.ones((3, 3), np.uint8) - mask = np.array(mask) - mask = cv2.dilate(mask, kernel, iterations=opt.mask_dilation) - mask = Image.fromarray(mask.astype("uint8")) - - input_image = irregular_hole_synthesize(input_image, mask) - mask = mask_transform(mask) - mask = mask[:1, :, :] ## Convert to single channel - mask = 
mask.unsqueeze(0) - input_image = img_transform(input_image) - input_image = input_image.unsqueeze(0) - - else: - if opt.test_mode == "Scale": - input_image = data_transforms(input_image, scale=True) - elif opt.test_mode == "Full": - input_image = data_transforms(input_image, scale=False) - elif opt.test_mode == "Crop": - input_image = data_transforms_rgb_old(input_image) - - input_image = img_transform(input_image) - input_image = input_image.unsqueeze(0) - mask = torch.zeros_like(input_image) - - with torch.no_grad(): - generated = model.inference(input_image, mask) - - restored = torch.clamp((generated.data.cpu() + 1.0) / 2.0, 0.0, 1.0) * 255 - return tensor2image(restored[0].byte()) diff --git a/Global/util/__init__.py b/Global/util/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/Global/util/image_pool.py b/Global/util/image_pool.py deleted file mode 100644 index 1e7846e..0000000 --- a/Global/util/image_pool.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import random -import torch -from torch.autograd import Variable - - -class ImagePool: - def __init__(self, pool_size): - self.pool_size = pool_size - if self.pool_size > 0: - self.num_imgs = 0 - self.images = [] - - def query(self, images): - if self.pool_size == 0: - return images - return_images = [] - for image in images.data: - image = torch.unsqueeze(image, 0) - if self.num_imgs < self.pool_size: - self.num_imgs = self.num_imgs + 1 - self.images.append(image) - return_images.append(image) - else: - p = random.uniform(0, 1) - if p > 0.5: - random_id = random.randint(0, self.pool_size - 1) - tmp = self.images[random_id].clone() - self.images[random_id] = image - return_images.append(tmp) - else: - return_images.append(image) - return_images = Variable(torch.cat(return_images, 0)) - return return_images diff --git a/Global/util/util.py b/Global/util/util.py deleted file mode 100644 index b1369c3..0000000 --- a/Global/util/util.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
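The deleted `global_test()` entry point above takes a checkpoint directory, a list of CLI-style flags, and PIL images, and returns the restored image. A rough usage sketch under the pre-submodule layout — the checkpoint folder, file names, and chosen flags below are illustrative assumptions, not values mandated by this patch:

```python
from PIL import Image

from Global.test import global_test  # import path before the move into lib_bopb2l

input_image = Image.open("old_photo.png").convert("RGB")
scratch_mask = Image.open("scratch_mask.png").convert("RGB")  # mask image (white = scratch)

restored = global_test(
    ckpt_dir="Global/checkpoints/restoration",  # folder holding the mapping_* checkpoints
    custom_args=["--test_mode", "Full", "--Scratch_and_Quality_restore", "--gpu_ids", "0"],
    input_image=input_image,
    mask=scratch_mask,
)
restored.save("restored.png")
```

With `--Scratch_and_Quality_restore`, `parameter_set()` selects the scratch-aware mapping checkpoints and the mask branch (dilation plus hole synthesis) is used; with `--Quality_restore` alone the mask argument can be omitted.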
- -from __future__ import print_function -import torch -import numpy as np -from PIL import Image -import numpy as np -import os -import torch.nn as nn - -# Converts a Tensor into a Numpy array -# |imtype|: the desired type of the converted numpy array -def tensor2im(image_tensor, imtype=np.uint8, normalize=True): - if isinstance(image_tensor, list): - image_numpy = [] - for i in range(len(image_tensor)): - image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) - return image_numpy - image_numpy = image_tensor.cpu().float().numpy() - if normalize: - image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 - else: - image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 - image_numpy = np.clip(image_numpy, 0, 255) - if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: - image_numpy = image_numpy[:, :, 0] - return image_numpy.astype(imtype) - - -# Converts a one-hot tensor into a colorful label map -def tensor2label(label_tensor, n_label, imtype=np.uint8): - if n_label == 0: - return tensor2im(label_tensor, imtype) - label_tensor = label_tensor.cpu().float() - if label_tensor.size()[0] > 1: - label_tensor = label_tensor.max(0, keepdim=True)[1] - label_tensor = Colorize(n_label)(label_tensor) - label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) - return label_numpy.astype(imtype) - - -def save_image(image_numpy, image_path): - image_pil = Image.fromarray(image_numpy) - image_pil.save(image_path) - - -def mkdirs(paths): - if isinstance(paths, list) and not isinstance(paths, str): - for path in paths: - mkdir(path) - else: - mkdir(paths) - - -def mkdir(path): - if not os.path.exists(path): - os.makedirs(path) diff --git a/LICENSE_MICROSOFT b/LICENSE_MICROSOFT deleted file mode 100644 index 9e841e7..0000000 --- a/LICENSE_MICROSOFT +++ /dev/null @@ -1,21 +0,0 @@ - MIT License - - Copyright (c) Microsoft Corporation. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE diff --git a/README.md b/README.md index c489534..83b1591 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,11 @@ This is an Extension for the [Automatic1111 Webui](https://github.com/AUTOMATIC1 ## Requirements 0. Install this Extension 1. Download `global_checkpoints.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases) -2. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/Global` +2. 
Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Global` 3. Download `face_checkpoints.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases) -4. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/Face_Enhancement` +4. Extract and put the `checkpoints` **folder** *(not just the files)* into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Face_Enhancement` 5. Download `shape_predictor_68_face_landmarks.zip` from [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases) -6. Extract the `.dat` **file** into `~webui/extensions/sd-webui-old-photo-restoration/Face_Detection` +6. Extract the `.dat` **file** into `~webui/extensions/sd-webui-old-photo-restoration/lib_bopb2l/Face_Detection` > The [Releases](https://github.com/Haoming02/sd-webui-old-photo-restoration/releases) page includes the original links, as well as the backups mirrored by myself @@ -40,3 +40,4 @@ After installing this Extension, there will be an **Old Photo Restoration** sect - *(This is **different** from the Webui built-in ones)* - **High Resolution:** Use higher parameters to do the processing - *(Only has an effect when either `Process Scratch` or `Face Restore` is also enabled)* +- **Use CPU:** Enable this if you do not have a Nvidia GPU or are getting **Out of Memory** Error diff --git a/lib_bopb2l b/lib_bopb2l new file mode 160000 index 0000000..b5b24d5 --- /dev/null +++ b/lib_bopb2l @@ -0,0 +1 @@ +Subproject commit b5b24d547512fca971a065f2fc3ef093d678701d diff --git a/preload.py b/preload.py index 84d04a8..9a9846f 100644 --- a/preload.py +++ b/preload.py @@ -2,14 +2,28 @@ import os EXTENSION_FOLDER = os.path.dirname(os.path.realpath(__file__)) -FACE = os.path.join(EXTENSION_FOLDER, "Face_Enhancement", "checkpoints") +FACE = os.path.join( + EXTENSION_FOLDER, + "lib_bopb2l", + "Face_Enhancement", + "checkpoints", +) +GLOBAL = os.path.join( + EXTENSION_FOLDER, + "lib_bopb2l", + "Global", + "checkpoints", +) +MDL = os.path.join( + EXTENSION_FOLDER, + "lib_bopb2l", + "Face_Detection", + "shape_predictor_68_face_landmarks.dat", +) + if not os.path.exists(FACE): print("\n[Warning] face_checkpoints not detected! Please download it from Release!") - -GLOBAL = os.path.join(EXTENSION_FOLDER, "Global", "checkpoints") if not os.path.exists(GLOBAL): print("[Warning] global_checkpoints not detected! Please download it from Release!") - -MDL = os.path.join(EXTENSION_FOLDER, "Face_Detection", "shape_predictor_68_face_landmarks.dat") if not os.path.exists(MDL): print("[Warning] face_landmarks not detected! 
Please download it from Release!\n") diff --git a/requirements.txt b/requirements.txt index 618f16b..517387c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -scikit-image -easydict PyYAML -dominate dill -tensorboardX -scipy -opencv-python +dominate +easydict einops matplotlib +opencv-python +scikit-image +scipy +tensorboardX diff --git a/scripts/bop_api.py b/scripts/bop_api.py index 54a2282..c1f0ac7 100644 --- a/scripts/bop_api.py +++ b/scripts/bop_api.py @@ -5,12 +5,12 @@ from modules.api import api from fastapi import FastAPI, Body import gradio as gr -from scripts.main_function import main +from scripts.bopb2l_main import main def bop_api(_: gr.Blocks, app: FastAPI): - @app.post("/bop") + @app.post("/bopb2l/restore") async def old_photo_restoration( image: str = Body("", title="input image"), scratch: bool = Body(False, title="process scratch"), @@ -22,9 +22,7 @@ def bop_api(_: gr.Blocks, app: FastAPI): return input_image = api.decode_base64_to_image(image) - img = main(input_image, scratch, hr, face_res, cpu) - return {"image": api.encode_pil_to_base64(img).decode("utf-8")} diff --git a/scripts/bop.py b/scripts/bopb2l.py similarity index 76% rename from scripts/bop.py rename to scripts/bopb2l.py index fba41f2..568ac99 100644 --- a/scripts/bop.py +++ b/scripts/bopb2l.py @@ -1,7 +1,8 @@ -from modules import scripts_postprocessing, ui_components +from modules.ui_components import InputAccordion +from modules import scripts_postprocessing import gradio as gr -from scripts.main_function import main +from scripts.bopb2l_main import main class OldPhotoRestoration(scripts_postprocessing.ScriptPostprocessing): @@ -9,22 +10,18 @@ class OldPhotoRestoration(scripts_postprocessing.ScriptPostprocessing): order = 200409484 def ui(self): - with ui_components.InputAccordion( - False, label="Old Photo Restoration" - ) as enable: - + with InputAccordion(False, label="Old Photo Restoration") as enable: proc_order = gr.Radio( - ["Restoration First", "Upscale First"], - label="Processing Order", + choices=("Restoration First", "Upscale First"), value="Restoration First", + label="Processing Order", ) with gr.Row(): - do_scratch = gr.Checkbox(label="Process Scratch") - do_face_res = gr.Checkbox(label="Face Restore") - + do_scratch = gr.Checkbox(False, label="Process Scratch") + do_face_res = gr.Checkbox(False, label="Face Restore") with gr.Row(): - is_hr = gr.Checkbox(label="High Resolution") + is_hr = gr.Checkbox(False, label="High Resolution") use_cpu = gr.Checkbox(True, label="Use CPU") args = { diff --git a/scripts/main_function.py b/scripts/bopb2l_main.py similarity index 83% rename from scripts/main_function.py rename to scripts/bopb2l_main.py index 0d84b20..b235b7f 100644 --- a/scripts/main_function.py +++ b/scripts/bopb2l_main.py @@ -1,12 +1,12 @@ -from Global.test import global_test -from Global.detection import global_detection +from lib_bopb2l.Global.test import global_test +from lib_bopb2l.Global.detection import global_detection -from Face_Detection.detect_all_dlib import detect -from Face_Detection.detect_all_dlib_HR import detect_hr -from Face_Detection.align_warp_back_multiple_dlib import align_warp -from Face_Detection.align_warp_back_multiple_dlib_HR import align_warp_hr +from lib_bopb2l.Face_Detection.detect_all_dlib import detect +from lib_bopb2l.Face_Detection.detect_all_dlib_HR import detect_hr +from lib_bopb2l.Face_Detection.align_warp_back_multiple_dlib import align_warp +from lib_bopb2l.Face_Detection.align_warp_back_multiple_dlib_HR import align_warp_hr -from 
Face_Enhancement.test_face import test_face
+from lib_bopb2l.Face_Enhancement.test_face import test_face
 
 from modules import scripts
 from PIL import Image
@@ -15,11 +15,12 @@
 import os
 
 GLOBAL_CHECKPOINTS_FOLDER = os.path.join(
-    scripts.basedir(), "Global", "checkpoints", "restoration"
+    scripts.basedir(), "lib_bopb2l", "Global", "checkpoints", "restoration"
 )
 FACE_CHECKPOINTS_FOLDER = os.path.join(
-    scripts.basedir(), "Face_Enhancement", "checkpoints"
+    scripts.basedir(), "lib_bopb2l", "Face_Enhancement", "checkpoints"
 )
+
 FACE_ENHANCEMENT_CHECKPOINTS = ("Setting_9_epoch_100", "FaceSR_512")
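On the API side, the route renamed above to `/bopb2l/restore` accepts a base64-encoded image plus the boolean switches forwarded to `main(...)`. A hedged client-side sketch — it assumes the Webui is running locally on its default port with the API enabled, and that the remaining body fields follow the parameter names `hr`, `face_res`, and `cpu` seen in the `main()` call (only `image` and `scratch` are explicitly visible in this hunk):

```python
import base64
import requests  # assumed available in the client environment

with open("old_photo.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image": encoded,   # base64 input image
    "scratch": True,    # process scratches
    "hr": False,        # high-resolution mode
    "face_res": True,   # run face restoration
    "cpu": True,        # fall back to CPU
}

resp = requests.post("http://127.0.0.1:7860/bopb2l/restore", json=payload)
resp.raise_for_status()

# The handler returns the restored image re-encoded as base64 in the "image" field.
with open("restored.png", "wb") as out:
    out.write(base64.b64decode(resp.json()["image"]))
```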