initial commit

2025-10-21 13:37:07 +07:00
commit 9cd16e276a
1574 changed files with 2675557 additions and 0 deletions


@@ -0,0 +1,332 @@
import math
import random
import hashlib
import logging
from enum import Enum
import cv2
import numpy as np
# from annotator.lama.saicinpainting.evaluation.masks.mask import SegmentationMask
from annotator.lama.saicinpainting.utils import LinearRamp
LOGGER = logging.getLogger(__name__)
class DrawMethod(Enum):
LINE = 'line'
CIRCLE = 'circle'
SQUARE = 'square'
def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
draw_method=DrawMethod.LINE):
draw_method = DrawMethod(draw_method)
height, width = shape
mask = np.zeros((height, width), np.float32)
times = np.random.randint(min_times, max_times + 1)
for i in range(times):
start_x = np.random.randint(width)
start_y = np.random.randint(height)
for j in range(1 + np.random.randint(5)):
angle = 0.01 + np.random.randint(max_angle)
if i % 2 == 0:
                angle = 2 * math.pi - angle
length = 10 + np.random.randint(max_len)
brush_w = 5 + np.random.randint(max_width)
end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
if draw_method == DrawMethod.LINE:
cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
elif draw_method == DrawMethod.CIRCLE:
cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
elif draw_method == DrawMethod.SQUARE:
radius = brush_w // 2
mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
start_x, start_y = end_x, end_y
return mask[None, ...]
class RandomIrregularMaskGenerator:
def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
draw_method=DrawMethod.LINE):
self.max_angle = max_angle
self.max_len = max_len
self.max_width = max_width
self.min_times = min_times
self.max_times = max_times
self.draw_method = draw_method
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
def __call__(self, img, iter_i=None, raw_image=None):
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
cur_max_len = int(max(1, self.max_len * coef))
cur_max_width = int(max(1, self.max_width * coef))
cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
draw_method=self.draw_method)
def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
height, width = shape
mask = np.zeros((height, width), np.float32)
bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
times = np.random.randint(min_times, max_times + 1)
for i in range(times):
box_width = np.random.randint(bbox_min_size, bbox_max_size)
box_height = np.random.randint(bbox_min_size, bbox_max_size)
start_x = np.random.randint(margin, width - margin - box_width + 1)
start_y = np.random.randint(margin, height - margin - box_height + 1)
mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
return mask[None, ...]
class RandomRectangleMaskGenerator:
def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
self.margin = margin
self.bbox_min_size = bbox_min_size
self.bbox_max_size = bbox_max_size
self.min_times = min_times
self.max_times = max_times
self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
def __call__(self, img, iter_i=None, raw_image=None):
coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
max_times=cur_max_times)
class RandomSegmentationMaskGenerator:
def __init__(self, **kwargs):
self.impl = None # will be instantiated in first call (effectively in subprocess)
self.kwargs = kwargs
def __call__(self, img, iter_i=None, raw_image=None):
if self.impl is None:
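            # NOTE: SegmentationMask comes from the import commented out at the top of this file,
            # so this generator raises NameError unless that dependency is restored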
self.impl = SegmentationMask(**self.kwargs)
masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))
masks = [m for m in masks if len(np.unique(m)) > 1]
        # np.random.choice cannot sample from a list of 2D arrays, so draw an index instead
        return masks[np.random.randint(len(masks))]
def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
height, width = shape
mask = np.zeros((height, width), np.float32)
step_x = np.random.randint(min_step, max_step + 1)
width_x = np.random.randint(min_width, min(step_x, max_width + 1))
offset_x = np.random.randint(0, step_x)
step_y = np.random.randint(min_step, max_step + 1)
width_y = np.random.randint(min_width, min(step_y, max_width + 1))
offset_y = np.random.randint(0, step_y)
for dy in range(width_y):
mask[offset_y + dy::step_y] = 1
for dx in range(width_x):
mask[:, offset_x + dx::step_x] = 1
return mask[None, ...]
class RandomSuperresMaskGenerator:
def __init__(self, **kwargs):
self.kwargs = kwargs
def __call__(self, img, iter_i=None):
return make_random_superres_mask(img.shape[1:], **self.kwargs)
class DumbAreaMaskGenerator:
min_ratio = 0.1
max_ratio = 0.35
default_ratio = 0.225
def __init__(self, is_training):
        # Parameters:
        #   is_training (bool): if True, generate a random rectangular mask; if False, a central square mask
self.is_training = is_training
def _random_vector(self, dimension):
if self.is_training:
lower_limit = math.sqrt(self.min_ratio)
upper_limit = math.sqrt(self.max_ratio)
mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
u = random.randint(0, dimension-mask_side-1)
v = u+mask_side
else:
margin = (math.sqrt(self.default_ratio) / 2) * dimension
u = round(dimension/2 - margin)
v = round(dimension/2 + margin)
return u, v
def __call__(self, img, iter_i=None, raw_image=None):
c, height, width = img.shape
mask = np.zeros((height, width), np.float32)
x1, x2 = self._random_vector(width)
y1, y2 = self._random_vector(height)
        mask[y1:y2, x1:x2] = 1  # index rows by the height-derived vector and columns by the width-derived one
return mask[None, ...]
class OutpaintingMaskGenerator:
    def __init__(self, min_padding_percent: float = 0.04, max_padding_percent: float = 0.25, left_padding_prob: float = 0.5,
                 top_padding_prob: float = 0.5, right_padding_prob: float = 0.5, bottom_padding_prob: float = 0.5,
                 is_fixed_randomness: bool = False):
"""
        is_fixed_randomness - produce identical paddings for the same image when the constructor args are the same
"""
self.min_padding_percent = min_padding_percent
self.max_padding_percent = max_padding_percent
self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
self.is_fixed_randomness = is_fixed_randomness
assert self.min_padding_percent <= self.max_padding_percent
assert self.max_padding_percent > 0
        assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x >= 0 and x <= 1)]) == 2, "Padding percentage should be in [0, 1]"
assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
if len([x for x in self.probs if x > 0]) == 1:
            LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means the outpainting masks will always be on the same side")
def apply_padding(self, mask, coord):
mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
return mask
def get_padding(self, size):
n1 = int(self.min_padding_percent*size)
n2 = int(self.max_padding_percent*size)
return self.rnd.randint(n1, n2) / size
@staticmethod
def _img2rs(img):
arr = np.ascontiguousarray(img.astype(np.uint8))
str_hash = hashlib.sha1(arr).hexdigest()
        res = int(str_hash, 16) % (2 ** 32)  # avoid built-in hash(): it is salted per process, which would break fixed randomness
return res
def __call__(self, img, iter_i=None, raw_image=None):
c, self.img_h, self.img_w = img.shape
mask = np.zeros((self.img_h, self.img_w), np.float32)
at_least_one_mask_applied = False
if self.is_fixed_randomness:
            assert raw_image is not None, "Cannot calculate hash on raw_image=None"
rs = self._img2rs(raw_image)
self.rnd = np.random.RandomState(rs)
else:
self.rnd = np.random
coords = [[
(0,0),
(1,self.get_padding(size=self.img_h))
],
[
(0,0),
(self.get_padding(size=self.img_w),1)
],
[
(0,1-self.get_padding(size=self.img_h)),
(1,1)
],
[
(1-self.get_padding(size=self.img_w),0),
(1,1)
]]
for pp, coord in zip(self.probs, coords):
if self.rnd.random() < pp:
at_least_one_mask_applied = True
mask = self.apply_padding(mask=mask, coord=coord)
if not at_least_one_mask_applied:
idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
mask = self.apply_padding(mask=mask, coord=coords[idx])
return mask[None, ...]
class MixedMaskGenerator:
def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
box_proba=1/3, box_kwargs=None,
segm_proba=1/3, segm_kwargs=None,
squares_proba=0, squares_kwargs=None,
superres_proba=0, superres_kwargs=None,
outpainting_proba=0, outpainting_kwargs=None,
invert_proba=0):
self.probas = []
self.gens = []
if irregular_proba > 0:
self.probas.append(irregular_proba)
if irregular_kwargs is None:
irregular_kwargs = {}
else:
irregular_kwargs = dict(irregular_kwargs)
irregular_kwargs['draw_method'] = DrawMethod.LINE
self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
if box_proba > 0:
self.probas.append(box_proba)
if box_kwargs is None:
box_kwargs = {}
self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
if segm_proba > 0:
self.probas.append(segm_proba)
if segm_kwargs is None:
segm_kwargs = {}
self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
if squares_proba > 0:
self.probas.append(squares_proba)
if squares_kwargs is None:
squares_kwargs = {}
else:
squares_kwargs = dict(squares_kwargs)
squares_kwargs['draw_method'] = DrawMethod.SQUARE
self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
if superres_proba > 0:
self.probas.append(superres_proba)
if superres_kwargs is None:
superres_kwargs = {}
self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
if outpainting_proba > 0:
self.probas.append(outpainting_proba)
if outpainting_kwargs is None:
outpainting_kwargs = {}
self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
self.probas = np.array(self.probas, dtype='float32')
self.probas /= self.probas.sum()
self.invert_proba = invert_proba
def __call__(self, img, iter_i=None, raw_image=None):
kind = np.random.choice(len(self.probas), p=self.probas)
gen = self.gens[kind]
result = gen(img, iter_i=iter_i, raw_image=raw_image)
if self.invert_proba > 0 and random.random() < self.invert_proba:
result = 1 - result
return result
def get_mask_generator(kind, kwargs):
if kind is None:
kind = "mixed"
if kwargs is None:
kwargs = {}
if kind == "mixed":
cl = MixedMaskGenerator
elif kind == "outpainting":
cl = OutpaintingMaskGenerator
elif kind == "dumb":
cl = DumbAreaMaskGenerator
else:
raise NotImplementedError(f"No such generator kind = {kind}")
return cl(**kwargs)
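
A quick usage sketch for the factory above (hedged: the CHW shape and kwargs are illustrative assumptions; segm_proba is set to 0 because RandomSegmentationMaskGenerator depends on the SegmentationMask import commented out at the top of this file):

import numpy as np

gen = get_mask_generator("mixed", {"irregular_proba": 1, "box_proba": 1, "segm_proba": 0})
img = np.zeros((3, 256, 256), dtype=np.float32)  # generators expect CHW arrays
mask = gen(img, iter_i=0)  # -> (1, 256, 256) float32 mask of 0s and 1s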


@@ -0,0 +1,177 @@
from typing import Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseAdversarialLoss:
def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
"""
Prepare for generator step
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param generator:
:param discriminator:
:return: None
"""
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
"""
Prepare for discriminator step
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param generator:
:param discriminator:
:return: None
"""
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Calculate generator loss
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param discr_real_pred: Tensor, discriminator output for real_batch
:param discr_fake_pred: Tensor, discriminator output for fake_batch
:param mask: Tensor, actual mask, which was at input of generator when making fake_batch
:return: total generator loss along with some values that might be interesting to log
"""
        raise NotImplementedError()
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
        Calculate discriminator loss (the caller is responsible for calling .backward() on it)
:param real_batch: Tensor, a batch of real samples
:param fake_batch: Tensor, a batch of samples produced by generator
:param discr_real_pred: Tensor, discriminator output for real_batch
:param discr_fake_pred: Tensor, discriminator output for fake_batch
:param mask: Tensor, actual mask, which was at input of generator when making fake_batch
:return: total discriminator loss along with some values that might be interesting to log
"""
        raise NotImplementedError()
def interpolate_mask(self, mask, shape):
assert mask is not None
assert self.allow_scale_mask or shape == mask.shape[-2:]
if shape != mask.shape[-2:] and self.allow_scale_mask:
if self.mask_scale_mode == 'maxpool':
mask = F.adaptive_max_pool2d(mask, shape)
else:
mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
return mask
def make_r1_gp(discr_real_pred, real_batch):
if torch.is_grad_enabled():
grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
else:
grad_penalty = 0
real_batch.requires_grad = False
return grad_penalty
class NonSaturatingWithR1(BaseAdversarialLoss):
def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
use_unmasked_for_gen=True, use_unmasked_for_discr=True):
self.gp_coef = gp_coef
self.weight = weight
        # use_unmasked_for_discr implies use_unmasked_for_gen;
        # otherwise we would teach only the discriminator to attend to very small differences
assert use_unmasked_for_gen or (not use_unmasked_for_discr)
# mask as target => use unmasked for discr:
# if we don't care about unmasked regions at all
# then it doesn't matter if the value of mask_as_fake_target is true or false
assert use_unmasked_for_discr or (not mask_as_fake_target)
self.use_unmasked_for_gen = use_unmasked_for_gen
self.use_unmasked_for_discr = use_unmasked_for_discr
self.mask_as_fake_target = mask_as_fake_target
self.allow_scale_mask = allow_scale_mask
self.mask_scale_mode = mask_scale_mode
self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask=None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
fake_loss = F.softplus(-discr_fake_pred)
if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
not self.use_unmasked_for_gen: # == if masked region should be treated differently
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
if not self.use_unmasked_for_gen:
fake_loss = fake_loss * mask
else:
pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
fake_loss = fake_loss * pixel_weights
return fake_loss.mean() * self.weight, dict()
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
real_batch.requires_grad = True
def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
mask=None) \
-> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
real_loss = F.softplus(-discr_real_pred)
grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
fake_loss = F.softplus(discr_fake_pred)
if not self.use_unmasked_for_discr or self.mask_as_fake_target:
# == if masked region should be treated differently
mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            # use_unmasked_for_discr=False only makes sense for fakes;
            # for reals there is no difference between the two regions
fake_loss = fake_loss * mask
if self.mask_as_fake_target:
fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)
sum_discr_loss = real_loss + grad_penalty + fake_loss
metrics = dict(discr_real_out=discr_real_pred.mean(),
discr_fake_out=discr_fake_pred.mean(),
discr_real_gp=grad_penalty)
return sum_discr_loss.mean(), metrics
class BCELoss(BaseAdversarialLoss):
def __init__(self, weight):
self.weight = weight
self.bce_loss = nn.BCEWithLogitsLoss()
def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)
fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
return fake_loss, dict()
def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
generator: nn.Module, discriminator: nn.Module):
real_batch.requires_grad = True
def discriminator_loss(self,
mask: torch.Tensor,
discr_real_pred: torch.Tensor,
discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)
sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
metrics = dict(discr_real_out=discr_real_pred.mean(),
discr_fake_out=discr_fake_pred.mean(),
discr_real_gp=0)
return sum_discr_loss, metrics
def make_discrim_loss(kind, **kwargs):
if kind == 'r1':
return NonSaturatingWithR1(**kwargs)
elif kind == 'bce':
return BCELoss(**kwargs)
raise ValueError(f'Unknown adversarial loss kind {kind}')
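
A minimal sketch of how these losses slot into a GAN training step; the stand-in discriminator, shapes and hyperparameters are illustrative assumptions, and the unused generator argument is passed as None:

import torch
import torch.nn as nn

adv_loss = make_discrim_loss('r1', gp_coef=0.001, weight=10)
discriminator = nn.Conv2d(3, 1, 3, padding=1)  # stand-in for a real discriminator
real_batch, fake_batch = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()

adv_loss.pre_discriminator_step(real_batch, fake_batch, None, discriminator)  # enables grad on real_batch for R1
d_loss, d_metrics = adv_loss.discriminator_loss(real_batch, fake_batch,
                                                discriminator(real_batch), discriminator(fake_batch.detach()),
                                                mask=mask)
g_loss, _ = adv_loss.generator_loss(real_batch, fake_batch,
                                    discriminator(real_batch), discriminator(fake_batch), mask=mask)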


@@ -0,0 +1,152 @@
weights = {"ade20k":
[6.34517766497462,
9.328358208955224,
11.389521640091116,
16.10305958132045,
20.833333333333332,
22.22222222222222,
25.125628140703515,
43.29004329004329,
50.5050505050505,
54.6448087431694,
55.24861878453038,
60.24096385542168,
62.5,
66.2251655629139,
84.74576271186442,
90.90909090909092,
91.74311926605505,
96.15384615384616,
96.15384615384616,
97.08737864077669,
102.04081632653062,
135.13513513513513,
149.2537313432836,
153.84615384615384,
163.93442622950818,
166.66666666666666,
188.67924528301887,
192.30769230769232,
217.3913043478261,
227.27272727272725,
227.27272727272725,
227.27272727272725,
303.03030303030306,
322.5806451612903,
333.3333333333333,
370.3703703703703,
384.61538461538464,
416.6666666666667,
416.6666666666667,
434.7826086956522,
434.7826086956522,
454.5454545454545,
454.5454545454545,
500.0,
526.3157894736842,
526.3157894736842,
555.5555555555555,
555.5555555555555,
555.5555555555555,
555.5555555555555,
555.5555555555555,
555.5555555555555,
555.5555555555555,
588.2352941176471,
588.2352941176471,
588.2352941176471,
588.2352941176471,
588.2352941176471,
666.6666666666666,
666.6666666666666,
666.6666666666666,
666.6666666666666,
714.2857142857143,
714.2857142857143,
714.2857142857143,
714.2857142857143,
714.2857142857143,
769.2307692307693,
769.2307692307693,
769.2307692307693,
833.3333333333334,
833.3333333333334,
833.3333333333334,
833.3333333333334,
909.090909090909,
1000.0,
1111.111111111111,
1111.111111111111,
1111.111111111111,
1111.111111111111,
1111.111111111111,
1250.0,
1250.0,
1250.0,
1250.0,
1250.0,
1428.5714285714287,
1428.5714285714287,
1428.5714285714287,
1428.5714285714287,
1428.5714285714287,
1428.5714285714287,
1428.5714285714287,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
1666.6666666666667,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
2500.0,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
3333.3333333333335,
5000.0,
5000.0,
5000.0]
}


@@ -0,0 +1,126 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from annotator.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN
def dummy_distance_weighter(real_img, pred_img, mask):
return mask
def get_gauss_kernel(kernel_size, width_factor=1):
coords = torch.stack(torch.meshgrid(torch.arange(kernel_size),
torch.arange(kernel_size)),
dim=0).float()
diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor)
diff /= diff.sum()
return diff
class BlurMask(nn.Module):
def __init__(self, kernel_size=5, width_factor=1):
super().__init__()
self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False)
self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor))
def forward(self, real_img, pred_img, mask):
with torch.no_grad():
result = self.filter(mask) * mask
return result
class EmulatedEDTMask(nn.Module):
def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1):
super().__init__()
self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate',
bias=False)
self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float))
self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False)
self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor))
def forward(self, real_img, pred_img, mask):
with torch.no_grad():
known_mask = 1 - mask
dilated_known_mask = (self.dilate_filter(known_mask) > 1).float()
result = self.blur_filter(1 - dilated_known_mask) * mask
return result
class PropagatePerceptualSim(nn.Module):
def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3):
super().__init__()
vgg = torchvision.models.vgg19(pretrained=True).features
vgg_avg_pooling = []
for weights in vgg.parameters():
weights.requires_grad = False
cur_level_i = 0
for module in vgg.modules():
if module.__class__.__name__ == 'Sequential':
continue
elif module.__class__.__name__ == 'MaxPool2d':
vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
else:
vgg_avg_pooling.append(module)
if module.__class__.__name__ == 'ReLU':
cur_level_i += 1
if cur_level_i == level:
break
self.features = nn.Sequential(*vgg_avg_pooling)
self.max_iters = max_iters
self.temperature = temperature
self.do_erode = erode_mask_size > 0
if self.do_erode:
self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False)
self.erode_mask.weight.data.fill_(1)
def forward(self, real_img, pred_img, mask):
with torch.no_grad():
real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img)
real_feats = self.features(real_img)
vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True)
/ self.temperature)
horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True)
/ self.temperature)
mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False)
if self.do_erode:
mask_scaled = (self.erode_mask(mask_scaled) > 1).float()
cur_knowness = 1 - mask_scaled
for iter_i in range(self.max_iters):
new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate')
new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate')
new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate')
new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate')
new_knowness = torch.stack([new_top_knowness, new_bottom_knowness,
new_left_knowness, new_right_knowness],
dim=0).max(0).values
cur_knowness = torch.max(cur_knowness, new_knowness)
cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear')
result = torch.min(mask, 1 - cur_knowness)
return result
def make_mask_distance_weighter(kind='none', **kwargs):
if kind == 'none':
return dummy_distance_weighter
if kind == 'blur':
return BlurMask(**kwargs)
if kind == 'edt':
return EmulatedEDTMask(**kwargs)
if kind == 'pps':
return PropagatePerceptualSim(**kwargs)
raise ValueError(f'Unknown mask distance weighter kind {kind}')
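
A minimal usage sketch for the factory above; BlurMask ignores its image arguments, so None is passed for them here:

import torch

weighter = make_mask_distance_weighter('blur', kernel_size=5)
mask = torch.zeros(1, 1, 32, 32)
mask[..., 8:24, 8:24] = 1
weighted = weighter(None, None, mask)  # (1, 1, 32, 32): mask softened towards its border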


@@ -0,0 +1,33 @@
from typing import List
import torch
import torch.nn.functional as F
def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
pixel_weights = mask * weight_missing + (1 - mask) * weight_known
return (pixel_weights * per_pixel_l2).mean()
def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
pixel_weights = mask * weight_missing + (1 - mask) * weight_known
return (pixel_weights * per_pixel_l1).mean()
def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
if mask is None:
res = torch.stack([F.mse_loss(fake_feat, target_feat)
for fake_feat, target_feat in zip(fake_features, target_features)]).mean()
else:
res = 0
norm = 0
for fake_feat, target_feat in zip(fake_features, target_features):
cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
error_weights = 1 - cur_mask
cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
res = res + cur_val
norm += 1
res = res / norm
return res
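
A minimal sketch of the weighted pixel losses above; the 10x weight on the missing region is an illustrative assumption:

import torch

pred, target = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1  # 1 marks the missing (inpainted) region
loss = masked_l1_loss(pred, target, mask, weight_known=1.0, weight_missing=10.0)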


@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# from models.ade20k import ModelBuilder
from annotator.lama.saicinpainting.utils import check_and_warn_input_range
IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
class PerceptualLoss(nn.Module):
def __init__(self, normalize_inputs=True):
super(PerceptualLoss, self).__init__()
self.normalize_inputs = normalize_inputs
self.mean_ = IMAGENET_MEAN
self.std_ = IMAGENET_STD
vgg = torchvision.models.vgg19(pretrained=True).features
vgg_avg_pooling = []
for weights in vgg.parameters():
weights.requires_grad = False
for module in vgg.modules():
if module.__class__.__name__ == 'Sequential':
continue
elif module.__class__.__name__ == 'MaxPool2d':
vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
else:
vgg_avg_pooling.append(module)
self.vgg = nn.Sequential(*vgg_avg_pooling)
def do_normalize_inputs(self, x):
return (x - self.mean_.to(x.device)) / self.std_.to(x.device)
def partial_losses(self, input, target, mask=None):
check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
# we expect input and target to be in [0, 1] range
losses = []
if self.normalize_inputs:
features_input = self.do_normalize_inputs(input)
features_target = self.do_normalize_inputs(target)
else:
features_input = input
features_target = target
for layer in self.vgg[:30]:
features_input = layer(features_input)
features_target = layer(features_target)
if layer.__class__.__name__ == 'ReLU':
loss = F.mse_loss(features_input, features_target, reduction='none')
if mask is not None:
cur_mask = F.interpolate(mask, size=features_input.shape[-2:],
mode='bilinear', align_corners=False)
loss = loss * (1 - cur_mask)
loss = loss.mean(dim=tuple(range(1, len(loss.shape))))
losses.append(loss)
return losses
def forward(self, input, target, mask=None):
losses = self.partial_losses(input, target, mask=mask)
return torch.stack(losses).sum(dim=0)
def get_global_features(self, input):
check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
if self.normalize_inputs:
features_input = self.do_normalize_inputs(input)
else:
features_input = input
features_input = self.vgg(features_input)
return features_input
class ResNetPL(nn.Module):
def __init__(self, weight=1,
weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
super().__init__()
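        # NOTE: ModelBuilder comes from the `models.ade20k` import commented out at the top of
        # this file, so constructing ResNetPL raises NameError unless that dependency is restored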
self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
arch_encoder=arch_encoder,
arch_decoder='ppm_deepsup',
fc_dim=2048,
segmentation=segmentation)
self.impl.eval()
for w in self.impl.parameters():
w.requires_grad_(False)
self.weight = weight
def forward(self, pred, target):
pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
pred_feats = self.impl(pred, return_feature_maps=True)
target_feats = self.impl(target, return_feature_maps=True)
result = torch.stack([F.mse_loss(cur_pred, cur_target)
for cur_pred, cur_target
in zip(pred_feats, target_feats)]).sum() * self.weight
return result
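
A minimal sketch of PerceptualLoss above (it downloads VGG19 ImageNet weights on first use; inputs should already be scaled to [0, 1], as the range check warns about):

import torch

pl = PerceptualLoss()
x, y = torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128)
per_sample = pl(x, y)  # shape (1,): sum over the selected VGG ReLU layers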


@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .constants import weights as constant_weights
class CrossEntropy2d(nn.Module):
def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
"""
weight (Tensor, optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size "nclasses"
"""
super(CrossEntropy2d, self).__init__()
self.reduction = reduction
self.ignore_label = ignore_label
self.weights = weights
if self.weights is not None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
def forward(self, predict, target):
"""
Args:
predict:(n, c, h, w)
target:(n, 1, h, w)
"""
target = target.long()
assert not target.requires_grad
assert predict.dim() == 4, "{0}".format(predict.size())
assert target.dim() == 4, "{0}".format(target.size())
assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
assert target.size(1) == 1, "{0}".format(target.size(1))
assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
target = target.squeeze(1)
n, c, h, w = predict.size()
target_mask = (target >= 0) * (target != self.ignore_label)
target = target[target_mask]
predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction)
return loss
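
A minimal sketch for CrossEntropy2d with NCHW logits and N1HW integer targets (150 ADE20K classes assumed; note that passing weights="ade20k" moves the class weights to CUDA when available, so inputs must then live on the same device):

import torch

criterion = CrossEntropy2d()
logits = torch.randn(2, 150, 32, 32)
target = torch.randint(0, 150, (2, 1, 32, 32))
loss = criterion(logits, target)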


@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import torchvision.models as models
class PerceptualLoss(nn.Module):
r"""
Perceptual loss, VGG-based
https://arxiv.org/abs/1603.08155
https://github.com/dxyang/StyleTransfer/blob/master/utils.py
"""
def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
super(PerceptualLoss, self).__init__()
self.add_module('vgg', VGG19())
self.criterion = torch.nn.L1Loss()
self.weights = weights
def __call__(self, x, y):
# Compute features
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
content_loss = 0.0
content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
return content_loss
class VGG19(torch.nn.Module):
def __init__(self):
super(VGG19, self).__init__()
features = models.vgg19(pretrained=True).features
self.relu1_1 = torch.nn.Sequential()
self.relu1_2 = torch.nn.Sequential()
self.relu2_1 = torch.nn.Sequential()
self.relu2_2 = torch.nn.Sequential()
self.relu3_1 = torch.nn.Sequential()
self.relu3_2 = torch.nn.Sequential()
self.relu3_3 = torch.nn.Sequential()
self.relu3_4 = torch.nn.Sequential()
self.relu4_1 = torch.nn.Sequential()
self.relu4_2 = torch.nn.Sequential()
self.relu4_3 = torch.nn.Sequential()
self.relu4_4 = torch.nn.Sequential()
self.relu5_1 = torch.nn.Sequential()
self.relu5_2 = torch.nn.Sequential()
self.relu5_3 = torch.nn.Sequential()
self.relu5_4 = torch.nn.Sequential()
for x in range(2):
self.relu1_1.add_module(str(x), features[x])
for x in range(2, 4):
self.relu1_2.add_module(str(x), features[x])
for x in range(4, 7):
self.relu2_1.add_module(str(x), features[x])
for x in range(7, 9):
self.relu2_2.add_module(str(x), features[x])
for x in range(9, 12):
self.relu3_1.add_module(str(x), features[x])
for x in range(12, 14):
self.relu3_2.add_module(str(x), features[x])
        for x in range(14, 16):
            self.relu3_3.add_module(str(x), features[x])
for x in range(16, 18):
self.relu3_4.add_module(str(x), features[x])
for x in range(18, 21):
self.relu4_1.add_module(str(x), features[x])
for x in range(21, 23):
self.relu4_2.add_module(str(x), features[x])
for x in range(23, 25):
self.relu4_3.add_module(str(x), features[x])
for x in range(25, 27):
self.relu4_4.add_module(str(x), features[x])
for x in range(27, 30):
self.relu5_1.add_module(str(x), features[x])
for x in range(30, 32):
self.relu5_2.add_module(str(x), features[x])
for x in range(32, 34):
self.relu5_3.add_module(str(x), features[x])
for x in range(34, 36):
self.relu5_4.add_module(str(x), features[x])
# don't need the gradients, just want the features
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
relu1_1 = self.relu1_1(x)
relu1_2 = self.relu1_2(relu1_1)
relu2_1 = self.relu2_1(relu1_2)
relu2_2 = self.relu2_2(relu2_1)
relu3_1 = self.relu3_1(relu2_2)
relu3_2 = self.relu3_2(relu3_1)
relu3_3 = self.relu3_3(relu3_2)
relu3_4 = self.relu3_4(relu3_3)
relu4_1 = self.relu4_1(relu3_4)
relu4_2 = self.relu4_2(relu4_1)
relu4_3 = self.relu4_3(relu4_2)
relu4_4 = self.relu4_4(relu4_3)
relu5_1 = self.relu5_1(relu4_4)
relu5_2 = self.relu5_2(relu5_1)
relu5_3 = self.relu5_3(relu5_2)
relu5_4 = self.relu5_4(relu5_3)
out = {
'relu1_1': relu1_1,
'relu1_2': relu1_2,
'relu2_1': relu2_1,
'relu2_2': relu2_2,
'relu3_1': relu3_1,
'relu3_2': relu3_2,
'relu3_3': relu3_3,
'relu3_4': relu3_4,
'relu4_1': relu4_1,
'relu4_2': relu4_2,
'relu4_3': relu4_3,
'relu4_4': relu4_4,
'relu5_1': relu5_1,
'relu5_2': relu5_2,
'relu5_3': relu5_3,
'relu5_4': relu5_4,
}
return out
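
A minimal sketch of this VGG-based style loss (weights download on first use; shapes are illustrative assumptions):

import torch

loss_fn = PerceptualLoss()
x, y = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
content_loss = loss_fn(x, y)  # scalar: weighted L1 over the relu*_1 feature maps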


@@ -0,0 +1,31 @@
import logging
from annotator.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator
from annotator.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
NLayerDiscriminator, MultidilatedNLayerDiscriminator
def make_generator(config, kind, **kwargs):
logging.info(f'Make generator {kind}')
if kind == 'pix2pixhd_multidilated':
return MultiDilatedGlobalGenerator(**kwargs)
if kind == 'pix2pixhd_global':
return GlobalGenerator(**kwargs)
if kind == 'ffc_resnet':
return FFCResNetGenerator(**kwargs)
raise ValueError(f'Unknown generator kind {kind}')
def make_discriminator(kind, **kwargs):
logging.info(f'Make discriminator {kind}')
if kind == 'pix2pixhd_nlayer_multidilated':
return MultidilatedNLayerDiscriminator(**kwargs)
if kind == 'pix2pixhd_nlayer':
return NLayerDiscriminator(**kwargs)
raise ValueError(f'Unknown discriminator kind {kind}')


@@ -0,0 +1,80 @@
import abc
from typing import Tuple, List
import torch
import torch.nn as nn
from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
class BaseDiscriminator(nn.Module):
@abc.abstractmethod
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Predict scores and get intermediate activations. Useful for feature matching loss
:return tuple (scores, list of intermediate activations)
"""
        raise NotImplementedError()
def get_conv_block_ctor(kind='default'):
if not isinstance(kind, str):
return kind
if kind == 'default':
return nn.Conv2d
if kind == 'depthwise':
return DepthWiseSeperableConv
if kind == 'multidilated':
return MultidilatedConv
raise ValueError(f'Unknown convolutional block kind {kind}')
def get_norm_layer(kind='bn'):
if not isinstance(kind, str):
return kind
if kind == 'bn':
return nn.BatchNorm2d
if kind == 'in':
return nn.InstanceNorm2d
raise ValueError(f'Unknown norm block kind {kind}')
def get_activation(kind='tanh'):
if kind == 'tanh':
return nn.Tanh()
if kind == 'sigmoid':
return nn.Sigmoid()
if kind is False:
return nn.Identity()
raise ValueError(f'Unknown activation kind {kind}')
class SimpleMultiStepGenerator(nn.Module):
def __init__(self, steps: List[nn.Module]):
super().__init__()
self.steps = nn.ModuleList(steps)
def forward(self, x):
cur_in = x
outs = []
for step in self.steps:
cur_out = step(cur_in)
outs.append(cur_out)
cur_in = torch.cat((cur_in, cur_out), dim=1)
return torch.cat(outs[::-1], dim=1)
def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
if kind == 'convtranspose':
return [nn.ConvTranspose2d(min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(min(max_features, int(ngf * mult / 2))), activation]
elif kind == 'bilinear':
return [nn.Upsample(scale_factor=2, mode='bilinear'),
DepthWiseSeperableConv(min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3, stride=1, padding=1),
norm_layer(min(max_features, int(ngf * mult / 2))), activation]
else:
raise Exception(f"Invalid deconv kind: {kind}")


@@ -0,0 +1,17 @@
import torch
import torch.nn as nn
class DepthWiseSeperableConv(nn.Module):
def __init__(self, in_dim, out_dim, *args, **kwargs):
super().__init__()
if 'groups' in kwargs:
# ignoring groups for Depthwise Sep Conv
del kwargs['groups']
self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
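
A minimal drop-in usage sketch; the channel counts are illustrative:

import torch

conv = DepthWiseSeperableConv(32, 64, kernel_size=3, padding=1)
y = conv(torch.rand(1, 32, 16, 16))  # -> (1, 64, 16, 16), with far fewer parameters than a dense 3x3 conv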


@@ -0,0 +1,47 @@
import torch
from kornia import SamplePadding
from kornia.augmentation import RandomAffine, CenterCrop
class FakeFakesGenerator:
def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
self.grad_aug = RandomAffine(degrees=360,
translate=0.2,
padding_mode=SamplePadding.REFLECTION,
keepdim=False,
p=1)
self.img_aug = RandomAffine(degrees=img_aug_degree,
translate=img_aug_translate,
padding_mode=SamplePadding.REFLECTION,
keepdim=True,
p=1)
self.aug_proba = aug_proba
def __call__(self, input_images, masks):
blend_masks = self._fill_masks_with_gradient(masks)
blend_target = self._make_blend_target(input_images)
result = input_images * (1 - blend_masks) + blend_target * blend_masks
return result, blend_masks
def _make_blend_target(self, input_images):
batch_size = input_images.shape[0]
permuted = input_images[torch.randperm(batch_size)]
augmented = self.img_aug(input_images)
is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
result = augmented * is_aug + permuted * (1 - is_aug)
return result
def _fill_masks_with_gradient(self, masks):
batch_size, _, height, width = masks.shape
grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
.view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
grad = self.grad_aug(grad)
grad = CenterCrop((height, width))(grad)
grad *= masks
grad_for_min = grad + (1 - masks) * 10
grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
grad.clamp_(min=0, max=1)
return grad


@@ -0,0 +1,485 @@
# Fast Fourier Convolution NeurIPS 2020
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.lama.saicinpainting.training.modules.base import get_activation, BaseDiscriminator
from annotator.lama.saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
from annotator.lama.saicinpainting.training.modules.squeeze_excitation import SELayer
from annotator.lama.saicinpainting.utils import get_shape
class FFCSE_block(nn.Module):
def __init__(self, channels, ratio_g):
super(FFCSE_block, self).__init__()
in_cg = int(channels * ratio_g)
in_cl = channels - in_cg
r = 16
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.conv1 = nn.Conv2d(channels, channels // r,
kernel_size=1, bias=True)
self.relu1 = nn.ReLU(inplace=True)
self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
channels // r, in_cl, kernel_size=1, bias=True)
self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
channels // r, in_cg, kernel_size=1, bias=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = x if type(x) is tuple else (x, 0)
id_l, id_g = x
x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
x = self.avgpool(x)
x = self.relu1(self.conv1(x))
x_l = 0 if self.conv_a2l is None else id_l * \
self.sigmoid(self.conv_a2l(x))
x_g = 0 if self.conv_a2g is None else id_g * \
self.sigmoid(self.conv_a2g(x))
return x_l, x_g
class FourierUnit(nn.Module):
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
# bn_layer not used
super(FourierUnit, self).__init__()
self.groups = groups
self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
out_channels=out_channels * 2,
kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
self.bn = torch.nn.BatchNorm2d(out_channels * 2)
self.relu = torch.nn.ReLU(inplace=True)
# squeeze and excitation block
self.use_se = use_se
if use_se:
if se_kwargs is None:
se_kwargs = {}
self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
self.spatial_scale_factor = spatial_scale_factor
self.spatial_scale_mode = spatial_scale_mode
self.spectral_pos_encoding = spectral_pos_encoding
self.ffc3d = ffc3d
self.fft_norm = fft_norm
def forward(self, x):
batch = x.shape[0]
if self.spatial_scale_factor is not None:
orig_size = x.shape[-2:]
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
r_size = x.size()
# (batch, c, h, w/2+1, 2)
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
if self.spectral_pos_encoding:
height, width = ffted.shape[-2:]
coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
if self.use_se:
ffted = self.se(ffted)
ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
ffted = self.relu(self.bn(ffted))
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
if self.spatial_scale_factor is not None:
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
return output
class SeparableFourierUnit(nn.Module):
def __init__(self, in_channels, out_channels, groups=1, kernel_size=3):
# bn_layer not used
super(SeparableFourierUnit, self).__init__()
self.groups = groups
row_out_channels = out_channels // 2
col_out_channels = out_channels - row_out_channels
self.row_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
out_channels=row_out_channels * 2,
kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
stride=1, padding=(kernel_size // 2, 0),
padding_mode='reflect',
groups=self.groups, bias=False)
self.col_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
out_channels=col_out_channels * 2,
kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed
stride=1, padding=(kernel_size // 2, 0),
padding_mode='reflect',
groups=self.groups, bias=False)
self.row_bn = torch.nn.BatchNorm2d(row_out_channels * 2)
self.col_bn = torch.nn.BatchNorm2d(col_out_channels * 2)
self.relu = torch.nn.ReLU(inplace=True)
def process_branch(self, x, conv, bn):
batch = x.shape[0]
r_size = x.size()
# (batch, c, h, w/2+1, 2)
ffted = torch.fft.rfft(x, norm="ortho")
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
ffted = self.relu(bn(conv(ffted)))
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
output = torch.fft.irfft(ffted, s=x.shape[-1:], norm="ortho")
return output
def forward(self, x):
rowwise = self.process_branch(x, self.row_conv, self.row_bn)
colwise = self.process_branch(x.permute(0, 1, 3, 2), self.col_conv, self.col_bn).permute(0, 1, 3, 2)
out = torch.cat((rowwise, colwise), dim=1)
return out
class SpectralTransform(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, separable_fu=False, **fu_kwargs):
# bn_layer not used
super(SpectralTransform, self).__init__()
self.enable_lfu = enable_lfu
if stride == 2:
self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
else:
self.downsample = nn.Identity()
self.stride = stride
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels //
2, kernel_size=1, groups=groups, bias=False),
nn.BatchNorm2d(out_channels // 2),
nn.ReLU(inplace=True)
)
fu_class = SeparableFourierUnit if separable_fu else FourierUnit
self.fu = fu_class(
out_channels // 2, out_channels // 2, groups, **fu_kwargs)
if self.enable_lfu:
self.lfu = fu_class(
out_channels // 2, out_channels // 2, groups)
self.conv2 = torch.nn.Conv2d(
out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
def forward(self, x):
x = self.downsample(x)
x = self.conv1(x)
output = self.fu(x)
if self.enable_lfu:
n, c, h, w = x.shape
split_no = 2
split_s = h // split_no
xs = torch.cat(torch.split(
x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
xs = torch.cat(torch.split(xs, split_s, dim=-1),
dim=1).contiguous()
xs = self.lfu(xs)
xs = xs.repeat(1, 1, split_no, split_no).contiguous()
else:
xs = 0
output = self.conv2(x + output + xs)
return output
class FFC(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
ratio_gin, ratio_gout, stride=1, padding=0,
dilation=1, groups=1, bias=False, enable_lfu=True,
padding_type='reflect', gated=False, **spectral_kwargs):
super(FFC, self).__init__()
assert stride == 1 or stride == 2, "Stride should be 1 or 2."
self.stride = stride
in_cg = int(in_channels * ratio_gin)
in_cl = in_channels - in_cg
out_cg = int(out_channels * ratio_gout)
out_cl = out_channels - out_cg
#groups_g = 1 if groups == 1 else int(groups * ratio_gout)
#groups_l = 1 if groups == 1 else groups - groups_g
self.ratio_gin = ratio_gin
self.ratio_gout = ratio_gout
self.global_in_num = in_cg
module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
self.convl2l = module(in_cl, out_cl, kernel_size,
stride, padding, dilation, groups, bias, padding_mode=padding_type)
module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
self.convl2g = module(in_cl, out_cg, kernel_size,
stride, padding, dilation, groups, bias, padding_mode=padding_type)
module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
self.convg2l = module(in_cg, out_cl, kernel_size,
stride, padding, dilation, groups, bias, padding_mode=padding_type)
module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
self.convg2g = module(
in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
self.gated = gated
module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
self.gate = module(in_channels, 2, 1)
def forward(self, x):
x_l, x_g = x if type(x) is tuple else (x, 0)
out_xl, out_xg = 0, 0
if self.gated:
total_input_parts = [x_l]
if torch.is_tensor(x_g):
total_input_parts.append(x_g)
total_input = torch.cat(total_input_parts, dim=1)
gates = torch.sigmoid(self.gate(total_input))
g2l_gate, l2g_gate = gates.chunk(2, dim=1)
else:
g2l_gate, l2g_gate = 1, 1
if self.ratio_gout != 1:
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
if self.ratio_gout != 0:
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
return out_xl, out_xg
class FFC_BN_ACT(nn.Module):
def __init__(self, in_channels, out_channels,
kernel_size, ratio_gin, ratio_gout,
stride=1, padding=0, dilation=1, groups=1, bias=False,
norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
padding_type='reflect',
enable_lfu=True, **kwargs):
super(FFC_BN_ACT, self).__init__()
self.ffc = FFC(in_channels, out_channels, kernel_size,
ratio_gin, ratio_gout, stride, padding, dilation,
groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
lnorm = nn.Identity if ratio_gout == 1 else norm_layer
gnorm = nn.Identity if ratio_gout == 0 else norm_layer
global_channels = int(out_channels * ratio_gout)
self.bn_l = lnorm(out_channels - global_channels)
self.bn_g = gnorm(global_channels)
lact = nn.Identity if ratio_gout == 1 else activation_layer
gact = nn.Identity if ratio_gout == 0 else activation_layer
self.act_l = lact(inplace=True)
self.act_g = gact(inplace=True)
def forward(self, x):
x_l, x_g = self.ffc(x)
x_l = self.act_l(self.bn_l(x_l))
x_g = self.act_g(self.bn_g(x_g))
return x_l, x_g
class FFCResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
spatial_transform_kwargs=None, inline=False, **conv_kwargs):
super().__init__()
self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
norm_layer=norm_layer,
activation_layer=activation_layer,
padding_type=padding_type,
**conv_kwargs)
self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
norm_layer=norm_layer,
activation_layer=activation_layer,
padding_type=padding_type,
**conv_kwargs)
if spatial_transform_kwargs is not None:
self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
self.inline = inline
def forward(self, x):
if self.inline:
x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
else:
x_l, x_g = x if type(x) is tuple else (x, 0)
id_l, id_g = x_l, x_g
x_l, x_g = self.conv1((x_l, x_g))
x_l, x_g = self.conv2((x_l, x_g))
x_l, x_g = id_l + x_l, id_g + x_g
out = x_l, x_g
if self.inline:
out = torch.cat(out, dim=1)
return out
class ConcatTupleLayer(nn.Module):
def forward(self, x):
assert isinstance(x, tuple)
x_l, x_g = x
assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
if not torch.is_tensor(x_g):
return x_l
return torch.cat(x, dim=1)
class FFCResNetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect', activation_layer=nn.ReLU,
up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
spatial_transform_layers=None, spatial_transform_kwargs={},
add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
assert (n_blocks >= 0)
super().__init__()
model = [nn.ReflectionPad2d(3),
FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
activation_layer=activation_layer, **init_conv_kwargs)]
### downsample
for i in range(n_downsampling):
mult = 2 ** i
if i == n_downsampling - 1:
cur_conv_kwargs = dict(downsample_conv_kwargs)
cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
else:
cur_conv_kwargs = downsample_conv_kwargs
model += [FFC_BN_ACT(min(max_features, ngf * mult),
min(max_features, ngf * mult * 2),
kernel_size=3, stride=2, padding=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
**cur_conv_kwargs)]
mult = 2 ** n_downsampling
feats_num_bottleneck = min(max_features, ngf * mult)
### resnet blocks
for i in range(n_blocks):
cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
norm_layer=norm_layer, **resnet_conv_kwargs)
if spatial_transform_layers is not None and i in spatial_transform_layers:
cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
model += [cur_resblock]
model += [ConcatTupleLayer()]
### upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3, stride=2, padding=1, output_padding=1),
up_norm_layer(min(max_features, int(ngf * mult / 2))),
up_activation]
if out_ffc:
model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
model += [nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
if add_out_act:
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class FFCNLayerDiscriminator(BaseDiscriminator):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512,
init_conv_kwargs={}, conv_kwargs={}):
super().__init__()
self.n_layers = n_layers
def _act_ctor(inplace=True):
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
kw = 3
padw = int(np.ceil((kw-1.0)/2))
sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer,
activation_layer=_act_ctor, **init_conv_kwargs)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, max_features)
cur_model = [
FFC_BN_ACT(nf_prev, nf,
kernel_size=kw, stride=2, padding=padw,
norm_layer=norm_layer,
activation_layer=_act_ctor,
**conv_kwargs)
]
sequence.append(cur_model)
nf_prev = nf
nf = min(nf * 2, 512)
cur_model = [
FFC_BN_ACT(nf_prev, nf,
kernel_size=kw, stride=1, padding=padw,
norm_layer=norm_layer,
activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs),
**conv_kwargs),
ConcatTupleLayer()
]
sequence.append(cur_model)
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
def get_all_activations(self, x):
res = [x]
for n in range(self.n_layers + 2):
model = getattr(self, 'model' + str(n))
res.append(model(res[-1]))
return res[1:]
def forward(self, x):
act = self.get_all_activations(x)
feats = []
for out in act[:-1]:
if isinstance(out, tuple):
if torch.is_tensor(out[1]):
out = torch.cat(out, dim=1)
else:
out = out[0]
feats.append(out)
return act[-1], feats
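
A minimal construction sketch for FFCResNetGenerator; the 4-channel input (RGB plus mask) and the ratio/LFU settings mirror typical LaMa configs but are assumptions here, not read from this file:

import torch

gen = FFCResNetGenerator(input_nc=4, output_nc=3,
                         init_conv_kwargs={'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False},
                         downsample_conv_kwargs={'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False},
                         resnet_conv_kwargs={'ratio_gin': 0.75, 'ratio_gout': 0.75, 'enable_lfu': False})
x = torch.rand(1, 4, 256, 256)  # masked image concatenated with its mask
out = gen(x)  # -> (1, 3, 256, 256), tanh-bounded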


@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import random
from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
class MultidilatedConv(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
super().__init__()
convs = []
self.equal_dim = equal_dim
assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
if comb_mode in ('cat_out', 'cat_both'):
self.cat_out = True
if equal_dim:
assert out_dim % dilation_num == 0
out_dims = [out_dim // dilation_num] * dilation_num
self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
else:
out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
out_dims.append(out_dim - sum(out_dims))
index = []
                # cumulative start offsets of each branch's channels in the concatenated output
                starts = [0]
                for d in out_dims[:-1]:
                    starts.append(starts[-1] + d)
lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
for i in range(out_dims[-1]):
for j in range(dilation_num):
index += list(range(starts[j], starts[j] + lengths[j]))
starts[j] += lengths[j]
self.index = index
assert(len(index) == out_dim)
self.out_dims = out_dims
else:
self.cat_out = False
self.out_dims = [out_dim] * dilation_num
if comb_mode in ('cat_in', 'cat_both'):
if equal_dim:
assert in_dim % dilation_num == 0
in_dims = [in_dim // dilation_num] * dilation_num
else:
in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
in_dims.append(in_dim - sum(in_dims))
self.in_dims = in_dims
self.cat_in = True
else:
self.cat_in = False
self.in_dims = [in_dim] * dilation_num
conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
dilation = min_dilation
for i in range(dilation_num):
if isinstance(padding, int):
cur_padding = padding * dilation
else:
cur_padding = padding[i]
convs.append(conv_type(
self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
))
if i > 0 and shared_weights:
convs[-1].weight = convs[0].weight
convs[-1].bias = convs[0].bias
dilation *= 2
self.convs = nn.ModuleList(convs)
self.shuffle_in_channels = shuffle_in_channels
if self.shuffle_in_channels:
# shuffle list as shuffling of tensors is nondeterministic
in_channels_permute = list(range(in_dim))
random.shuffle(in_channels_permute)
# save as buffer so it is saved and loaded with checkpoint
self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
def forward(self, x):
if self.shuffle_in_channels:
x = x[:, self.in_channels_permute]
outs = []
if self.cat_in:
if self.equal_dim:
x = x.chunk(len(self.convs), dim=1)
else:
new_x = []
start = 0
for dim in self.in_dims:
new_x.append(x[:, start:start+dim])
start += dim
x = new_x
for i, conv in enumerate(self.convs):
if self.cat_in:
input = x[i]
else:
input = x
outs.append(conv(input))
if self.cat_out:
out = torch.cat(outs, dim=1)[:, self.index]
else:
out = sum(outs)
return out
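
if __name__ == '__main__':
    # Usage sketch (illustrative): three parallel branches with dilations
    # 1, 2 and 4; padding grows with each branch's dilation, so the spatial
    # size is preserved and the branch outputs are summed.
    conv = MultidilatedConv(16, 32, kernel_size=3, dilation_num=3, comb_mode='sum')
    y = conv(torch.randn(2, 16, 64, 64))
    print(y.shape)  # torch.Size([2, 32, 64, 64])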

View File

@@ -0,0 +1,244 @@
from typing import List, Tuple, Union, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
from annotator.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock
class ResNetHead(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
assert (n_blocks >= 0)
super(ResNetHead, self).__init__()
conv_layer = get_conv_block_ctor(conv_kind)
model = [nn.ReflectionPad2d(3),
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
activation]
### downsample
for i in range(n_downsampling):
mult = 2 ** i
model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2),
activation]
mult = 2 ** n_downsampling
### resnet blocks
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
conv_kind=conv_kind)]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class ResNetTail(nn.Module):
def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
add_in_proj=None):
assert (n_blocks >= 0)
super(ResNetTail, self).__init__()
mult = 2 ** n_downsampling
model = []
if add_in_proj is not None:
model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1))
### resnet blocks
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
conv_kind=conv_kind)]
### upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
output_padding=1),
up_norm_layer(int(ngf * mult / 2)),
up_activation]
self.model = nn.Sequential(*model)
out_layers = []
for _ in range(out_extra_layers_n):
out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0),
up_norm_layer(ngf),
up_activation]
out_layers += [nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
if add_out_act:
out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act))
self.out_proj = nn.Sequential(*out_layers)
def forward(self, input, return_last_act=False):
features = self.model(input)
out = self.out_proj(features)
if return_last_act:
return out, features
else:
return out
class MultiscaleResNet(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
out_cumulative=False, return_only_hr=False):
super().__init__()
self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
conv_kind=conv_kind, activation=activation)
for i in range(n_scales)])
tail_in_feats = ngf * (2 ** n_downsampling) + ngf
self.tails = nn.ModuleList([ResNetTail(output_nc,
ngf=ngf, n_downsampling=n_downsampling,
n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
up_activation=up_activation, add_out_act=add_out_act,
out_extra_layers_n=out_extra_layers_n,
add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
for i in range(n_scales)])
self.out_cumulative = out_cumulative
self.return_only_hr = return_only_hr
@property
def num_scales(self):
return len(self.heads)
def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
-> Union[torch.Tensor, List[torch.Tensor]]:
"""
:param ms_inputs: List of inputs of different resolutions from HR to LR
        :param smallest_scales_num: int or None, number of smallest scales to take as input
:return: Depending on return_only_hr:
True: Only the most HR output
False: List of outputs of different resolutions from HR to LR
"""
if smallest_scales_num is None:
assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
smallest_scales_num = len(self.heads)
else:
assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
cur_heads = self.heads[-smallest_scales_num:]
ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
all_outputs = []
prev_tail_features = None
for i in range(len(ms_features)):
scale_i = -i - 1
cur_tail_input = ms_features[-i - 1]
if prev_tail_features is not None:
if prev_tail_features.shape != cur_tail_input.shape:
prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
mode='bilinear', align_corners=False)
cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
prev_tail_features = cur_tail_feats
all_outputs.append(cur_out)
if self.out_cumulative:
all_outputs_cum = [all_outputs[0]]
for i in range(1, len(ms_features)):
cur_out = all_outputs[i]
cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
mode='bilinear', align_corners=False)
all_outputs_cum.append(cur_out_cum)
all_outputs = all_outputs_cum
if self.return_only_hr:
return all_outputs[-1]
else:
return all_outputs[::-1]
class MultiscaleDiscriminatorSimple(nn.Module):
def __init__(self, ms_impl):
super().__init__()
self.ms_impl = nn.ModuleList(ms_impl)
@property
def num_scales(self):
return len(self.ms_impl)
def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
-> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
:param ms_inputs: List of inputs of different resolutions from HR to LR
        :param smallest_scales_num: int or None, number of smallest scales to take as input
:return: List of pairs (prediction, features) for different resolutions from HR to LR
"""
if smallest_scales_num is None:
assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
            smallest_scales_num = len(self.ms_impl)
else:
assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
(len(self.ms_impl), len(ms_inputs), smallest_scales_num)
return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
class SingleToMultiScaleInputMixin:
def forward(self, x: torch.Tensor) -> List:
orig_height, orig_width = x.shape[2:]
factors = [2 ** i for i in range(self.num_scales)]
ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False)
for f in factors]
return super().forward(ms_inputs)
class GeneratorMultiToSingleOutputMixin:
def forward(self, x):
return super().forward(x)[0]
class DiscriminatorMultiToSingleOutputMixin:
def forward(self, x):
out_feat_tuples = super().forward(x)
return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist]
class DiscriminatorMultiToSingleOutputStackedMixin:
def __init__(self, *args, return_feats_only_levels=None, **kwargs):
super().__init__(*args, **kwargs)
self.return_feats_only_levels = return_feats_only_levels
def forward(self, x):
out_feat_tuples = super().forward(x)
outs = [out for out, _ in out_feat_tuples]
scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:],
mode='bilinear', align_corners=False)
for cur_out in outs[1:]]
out = torch.cat(scaled_outs, dim=1)
if self.return_feats_only_levels is not None:
feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels]
else:
feat_lists = [flist for _, flist in out_feat_tuples]
feats = [f for flist in feat_lists for f in flist]
return out, feats
class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
pass
class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
pass
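
if __name__ == '__main__':
    # Usage sketch (illustrative): the mixins build the HR->LR pyramid from a
    # single tensor and return only the highest-resolution output.
    net = MultiscaleResNetSingle(input_nc=4, output_nc=3)
    out = net(torch.randn(1, 4, 128, 128))
    print(out.shape)  # torch.Size([1, 3, 128, 128])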

View File

@@ -0,0 +1,669 @@
# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
import collections
from functools import partial
import functools
import logging
from collections import defaultdict
import numpy as np
import torch.nn as nn
from annotator.lama.saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
from annotator.lama.saicinpainting.training.modules.ffc import FFCResnetBlock
from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
class DotDict(defaultdict):
# https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
"""dot.notation access to dictionary attributes"""
__getattr__ = defaultdict.get
__setattr__ = defaultdict.__setitem__
__delattr__ = defaultdict.__delitem__
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
dilation=1, in_dim=None, groups=1, second_dilation=None):
super(ResnetBlock, self).__init__()
self.in_dim = in_dim
self.dim = dim
if second_dilation is None:
second_dilation = dilation
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
second_dilation=second_dilation)
if self.in_dim is not None:
self.input_conv = nn.Conv2d(in_dim, dim, 1)
        self.out_channels = dim
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
dilation=1, in_dim=None, groups=1, second_dilation=1):
conv_layer = get_conv_block_ctor(conv_kind)
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(dilation)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(dilation)]
elif padding_type == 'zero':
p = dilation
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
if in_dim is None:
in_dim = dim
conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation),
norm_layer(dim),
activation]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(second_dilation)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(second_dilation)]
elif padding_type == 'zero':
p = second_dilation
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
x_before = x
if self.in_dim is not None:
x = self.input_conv(x)
out = x + self.conv_block(x_before)
return out
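
# Example (illustrative): ResnetBlock(64, 'reflect', nn.BatchNorm2d, in_dim=32)
# maps (N, 32, H, W) -> (N, 64, H, W): the skip path is projected by a 1x1 conv
# (self.input_conv) while the residual branch consumes the raw input.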
class ResnetBlock5x5(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
dilation=1, in_dim=None, groups=1, second_dilation=None):
super(ResnetBlock5x5, self).__init__()
self.in_dim = in_dim
self.dim = dim
if second_dilation is None:
second_dilation = dilation
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
second_dilation=second_dilation)
if self.in_dim is not None:
self.input_conv = nn.Conv2d(in_dim, dim, 1)
        self.out_channels = dim
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
dilation=1, in_dim=None, groups=1, second_dilation=1):
conv_layer = get_conv_block_ctor(conv_kind)
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(dilation * 2)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(dilation * 2)]
elif padding_type == 'zero':
p = dilation * 2
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
if in_dim is None:
in_dim = dim
conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
norm_layer(dim),
activation]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
elif padding_type == 'zero':
p = second_dilation * 2
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
x_before = x
if self.in_dim is not None:
x = self.input_conv(x)
out = x + self.conv_block(x_before)
return out
class MultidilatedResnetBlock(nn.Module):
def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
super().__init__()
self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)
def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
conv_block = []
conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
norm_layer(dim),
activation]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class MultiDilatedGlobalGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
n_blocks=3, norm_layer=nn.BatchNorm2d,
padding_type='reflect', conv_kind='default',
deconv_kind='convtranspose', activation=nn.ReLU(True),
up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
add_out_act=True, max_features=1024, multidilation_kwargs={},
ffc_positions=None, ffc_kwargs={}):
assert (n_blocks >= 0)
super().__init__()
conv_layer = get_conv_block_ctor(conv_kind)
resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
norm_layer = get_norm_layer(norm_layer)
if affine is not None:
norm_layer = partial(norm_layer, affine=affine)
up_norm_layer = get_norm_layer(up_norm_layer)
if affine is not None:
up_norm_layer = partial(up_norm_layer, affine=affine)
model = [nn.ReflectionPad2d(3),
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
activation]
identity = Identity()
### downsample
for i in range(n_downsampling):
mult = 2 ** i
model += [conv_layer(min(max_features, ngf * mult),
min(max_features, ngf * mult * 2),
kernel_size=3, stride=2, padding=1),
norm_layer(min(max_features, ngf * mult * 2)),
activation]
mult = 2 ** n_downsampling
feats_num_bottleneck = min(max_features, ngf * mult)
### resnet blocks
for i in range(n_blocks):
if ffc_positions is not None and i in ffc_positions:
model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
inline=True, **ffc_kwargs)]
model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
conv_layer=resnet_conv_layer, activation=activation,
norm_layer=norm_layer)]
### upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
model += [nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
if add_out_act:
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
class ConfigGlobalGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
n_blocks=3, norm_layer=nn.BatchNorm2d,
padding_type='reflect', conv_kind='default',
deconv_kind='convtranspose', activation=nn.ReLU(True),
up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
add_out_act=True, max_features=1024,
manual_block_spec=[],
resnet_block_kind='multidilatedresnetblock',
resnet_conv_kind='multidilated',
resnet_dilation=1,
multidilation_kwargs={}):
assert (n_blocks >= 0)
super().__init__()
conv_layer = get_conv_block_ctor(conv_kind)
resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
norm_layer = get_norm_layer(norm_layer)
if affine is not None:
norm_layer = partial(norm_layer, affine=affine)
up_norm_layer = get_norm_layer(up_norm_layer)
if affine is not None:
up_norm_layer = partial(up_norm_layer, affine=affine)
model = [nn.ReflectionPad2d(3),
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
activation]
identity = Identity()
### downsample
for i in range(n_downsampling):
mult = 2 ** i
model += [conv_layer(min(max_features, ngf * mult),
min(max_features, ngf * mult * 2),
kernel_size=3, stride=2, padding=1),
norm_layer(min(max_features, ngf * mult * 2)),
activation]
mult = 2 ** n_downsampling
feats_num_bottleneck = min(max_features, ngf * mult)
if len(manual_block_spec) == 0:
manual_block_spec = [
DotDict(lambda : None, {
'n_blocks': n_blocks,
'use_default': True})
]
### resnet blocks
for block_spec in manual_block_spec:
def make_and_add_blocks(model, block_spec):
block_spec = DotDict(lambda : None, block_spec)
if not block_spec.use_default:
resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs)
resnet_conv_kind = block_spec.resnet_conv_kind
resnet_block_kind = block_spec.resnet_block_kind
if block_spec.resnet_dilation is not None:
resnet_dilation = block_spec.resnet_dilation
for i in range(block_spec.n_blocks):
if resnet_block_kind == "multidilatedresnetblock":
model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
conv_layer=resnet_conv_layer, activation=activation,
norm_layer=norm_layer)]
if resnet_block_kind == "resnetblock":
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
conv_kind=resnet_conv_kind)]
if resnet_block_kind == "resnetblock5x5":
model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
conv_kind=resnet_conv_kind)]
if resnet_block_kind == "resnetblockdwdil":
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer,
conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)]
make_and_add_blocks(model, block_spec)
### upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
model += [nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
if add_out_act:
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
blocks = []
for i in range(dilated_blocks_n):
if dilation_block_kind == 'simple':
blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1)))
elif dilation_block_kind == 'multi':
blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs))
else:
            raise ValueError(f'Unknown dilation_block_kind "{dilation_block_kind}"')
return blocks
class GlobalGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
up_norm_layer=nn.BatchNorm2d, affine=None,
up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
dilated_blocks_n_middle=0,
add_out_act=True,
max_features=1024, is_resblock_depthwise=False,
ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
dilation_block_kind='simple', multidilation_kwargs={}):
assert (n_blocks >= 0)
super().__init__()
conv_layer = get_conv_block_ctor(conv_kind)
norm_layer = get_norm_layer(norm_layer)
if affine is not None:
norm_layer = partial(norm_layer, affine=affine)
up_norm_layer = get_norm_layer(up_norm_layer)
if affine is not None:
up_norm_layer = partial(up_norm_layer, affine=affine)
if ffc_positions is not None:
ffc_positions = collections.Counter(ffc_positions)
model = [nn.ReflectionPad2d(3),
conv_layer(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
activation]
identity = Identity()
### downsample
for i in range(n_downsampling):
mult = 2 ** i
model += [conv_layer(min(max_features, ngf * mult),
min(max_features, ngf * mult * 2),
kernel_size=3, stride=2, padding=1),
norm_layer(min(max_features, ngf * mult * 2)),
activation]
mult = 2 ** n_downsampling
feats_num_bottleneck = min(max_features, ngf * mult)
dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
activation=activation, norm_layer=norm_layer)
if dilation_block_kind == 'simple':
dilated_block_kwargs['conv_kind'] = conv_kind
elif dilation_block_kind == 'multi':
dilated_block_kwargs['conv_layer'] = functools.partial(
get_conv_block_ctor('multidilated'), **multidilation_kwargs)
# dilated blocks at the start of the bottleneck sausage
if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)
# resnet blocks
for i in range(n_blocks):
# dilated blocks at the middle of the bottleneck sausage
if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)
if ffc_positions is not None and i in ffc_positions:
for _ in range(ffc_positions[i]): # same position can occur more than once
model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
inline=True, **ffc_kwargs)]
if is_resblock_depthwise:
resblock_groups = feats_num_bottleneck
else:
resblock_groups = 1
model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
dilation=dilation, second_dilation=second_dilation)]
# dilated blocks at the end of the bottleneck sausage
if dilated_blocks_n is not None and dilated_blocks_n > 0:
model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)
# upsample
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
min(max_features, int(ngf * mult / 2)),
kernel_size=3, stride=2, padding=1, output_padding=1),
up_norm_layer(min(max_features, int(ngf * mult / 2))),
up_activation]
model += [nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
if add_out_act:
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
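
# Example (illustrative): GlobalGenerator(input_nc=4, output_nc=3) with the
# defaults (ngf=64, n_downsampling=3, n_blocks=9) maps (N, 4, H, W) to
# (N, 3, H, W) for H and W divisible by 8; the bottleneck runs at 1/8
# resolution with min(max_features, ngf * 8) = 512 channels.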
class GlobalGeneratorGated(GlobalGenerator):
def __init__(self, *args, **kwargs):
real_kwargs=dict(
conv_kind='gated_bn_relu',
activation=nn.Identity(),
norm_layer=nn.Identity
)
real_kwargs.update(kwargs)
super().__init__(*args, **real_kwargs)
class GlobalGeneratorFromSuperChannels(nn.Module):
def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
super().__init__()
self.n_downsampling = n_downsampling
norm_layer = get_norm_layer(norm_layer)
if type(norm_layer) == functools.partial:
use_bias = (norm_layer.func == nn.InstanceNorm2d)
else:
use_bias = (norm_layer == nn.InstanceNorm2d)
channels = self.convert_super_channels(super_channels)
self.channels = channels
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
norm_layer(channels[0]),
nn.ReLU(True)]
for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(channels[1+i]),
nn.ReLU(True)]
mult = 2 ** n_downsampling
n_blocks1 = n_blocks // 3
n_blocks2 = n_blocks1
n_blocks3 = n_blocks - n_blocks1 - n_blocks2
for i in range(n_blocks1):
c = n_downsampling
dim = channels[c]
model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]
for i in range(n_blocks2):
c = n_downsampling+1
dim = channels[c]
kwargs = {}
if i == 0:
kwargs = {"in_dim": channels[c-1]}
model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
for i in range(n_blocks3):
c = n_downsampling+2
dim = channels[c]
kwargs = {}
if i == 0:
kwargs = {"in_dim": channels[c-1]}
model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(channels[n_downsampling+3+i],
channels[n_downsampling+3+i+1],
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(channels[n_downsampling+3+i+1]),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)]
if add_out_act:
model.append(get_activation('tanh' if add_out_act is True else add_out_act))
self.model = nn.Sequential(*model)
def convert_super_channels(self, super_channels):
n_downsampling = self.n_downsampling
result = []
cnt = 0
if n_downsampling == 2:
N1 = 10
elif n_downsampling == 3:
N1 = 13
else:
raise NotImplementedError
        for i in range(N1):
            if i in [1, 4, 7, 10]:
                channel = super_channels[cnt] * (2 ** cnt)
                result.append(channel)
                logging.info(f"Downsample channels {result[-1]}")
                cnt += 1
for i in range(3):
for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
if len(super_channels) == 6:
channel = super_channels[3] * 4
else:
channel = super_channels[i + 3] * 4
if counter == 0:
result.append(channel)
logging.info(f"Bottleneck channels {result[-1]}")
cnt = 2
for i in range(N1+9, N1+21):
            if i in [22, 25, 28]:
cnt -= 1
if len(super_channels) == 6:
channel = super_channels[5 - cnt] * (2 ** cnt)
else:
channel = super_channels[7 - cnt] * (2 ** cnt)
result.append(int(channel))
logging.info(f"Upsample channels {result[-1]}")
return result
def forward(self, input):
return self.model(input)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseDiscriminator):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
super().__init__()
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
cur_model = []
cur_model += [
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]
sequence.append(cur_model)
nf_prev = nf
nf = min(nf * 2, 512)
cur_model = []
cur_model += [
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]
sequence.append(cur_model)
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
def get_all_activations(self, x):
res = [x]
for n in range(self.n_layers + 2):
model = getattr(self, 'model' + str(n))
res.append(model(res[-1]))
return res[1:]
def forward(self, x):
act = self.get_all_activations(x)
return act[-1], act[:-1]
class MultidilatedNLayerDiscriminator(BaseDiscriminator):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
super().__init__()
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
cur_model = []
cur_model += [
MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]
sequence.append(cur_model)
nf_prev = nf
nf = min(nf * 2, 512)
cur_model = []
cur_model += [
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]
sequence.append(cur_model)
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
def get_all_activations(self, x):
res = [x]
for n in range(self.n_layers + 2):
model = getattr(self, 'model' + str(n))
res.append(model(res[-1]))
return res[1:]
def forward(self, x):
act = self.get_all_activations(x)
return act[-1], act[:-1]
class NLayerDiscriminatorAsGen(NLayerDiscriminator):
def forward(self, x):
return super().forward(x)[0]
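
if __name__ == '__main__':
    # Smoke test (illustrative): a PatchGAN discriminator over a 256x256 input
    # returns a map of patch logits plus n_layers + 1 intermediate activations.
    import torch
    discr = NLayerDiscriminator(input_nc=3)
    pred, feats = discr(torch.randn(1, 3, 256, 256))
    print(pred.shape, len(feats))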

View File

@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry.transform import rotate
class LearnableSpatialTransformWrapper(nn.Module):
def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
super().__init__()
self.impl = impl
self.angle = torch.rand(1) * angle_init_range
if train_angle:
self.angle = nn.Parameter(self.angle, requires_grad=True)
self.pad_coef = pad_coef
def forward(self, x):
if torch.is_tensor(x):
return self.inverse_transform(self.impl(self.transform(x)), x)
elif isinstance(x, tuple):
x_trans = tuple(self.transform(elem) for elem in x)
y_trans = self.impl(x_trans)
return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
else:
raise ValueError(f'Unexpected input type {type(x)}')
def transform(self, x):
height, width = x.shape[2:]
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
return x_padded_rotated
def inverse_transform(self, y_padded_rotated, orig_x):
height, width = orig_x.shape[2:]
pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
y_height, y_width = y_padded.shape[2:]
y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
return y
if __name__ == '__main__':
layer = LearnableSpatialTransformWrapper(nn.Identity())
    x = torch.arange(2 * 3 * 15 * 15).view(2, 3, 15, 15).float()
y = layer(x)
assert x.shape == y.shape
assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])
print('all ok')

View File

@@ -0,0 +1,20 @@
import torch.nn as nn
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
res = x * y.expand_as(x)
return res
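
if __name__ == '__main__':
    # Smoke test (illustrative): squeeze-and-excitation re-weights channels
    # and keeps the input shape unchanged.
    import torch
    se = SELayer(channel=64)
    y = se(torch.rand(2, 64, 8, 8))
    print(y.shape)  # torch.Size([2, 64, 8, 8])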

View File

@@ -0,0 +1,29 @@
import logging
import torch
from annotator.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
def get_training_model_class(kind):
if kind == 'default':
return DefaultInpaintingTrainingModule
raise ValueError(f'Unknown trainer module {kind}')
def make_training_model(config):
kind = config.training_model.kind
kwargs = dict(config.training_model)
kwargs.pop('kind')
kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
logging.info(f'Make training model {kind}')
cls = get_training_model_class(kind)
return cls(config, **kwargs)
def load_checkpoint(train_config, path, map_location='cuda', strict=True):
model = make_training_model(train_config).generator
state = torch.load(path, map_location=map_location)
model.load_state_dict(state, strict=strict)
return model
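
# Usage sketch (illustrative; the config file layout and checkpoint path below
# are assumptions, not part of this repo):
#   from omegaconf import OmegaConf
#   train_config = OmegaConf.load('config.yaml')  # must define training_model, trainer and generator sections
#   generator = load_checkpoint(train_config, 'last.ckpt', map_location='cpu')
#   generator.eval()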

View File

@@ -0,0 +1,293 @@
import copy
import logging
from typing import Dict, Tuple
import pandas as pd
import pytorch_lightning as ptl
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.data import DistributedSampler
# from annotator.lama.saicinpainting.evaluation import make_evaluator
# from annotator.lama.saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader
# from annotator.lama.saicinpainting.training.losses.adversarial import make_discrim_loss
# from annotator.lama.saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL
from annotator.lama.saicinpainting.training.modules import make_generator #, make_discriminator
# from annotator.lama.saicinpainting.training.visualizers import make_visualizer
from annotator.lama.saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \
get_has_ddp_rank
LOGGER = logging.getLogger(__name__)
def make_optimizer(parameters, kind='adamw', **kwargs):
if kind == 'adam':
optimizer_class = torch.optim.Adam
elif kind == 'adamw':
optimizer_class = torch.optim.AdamW
else:
raise ValueError(f'Unknown optimizer kind {kind}')
return optimizer_class(parameters, **kwargs)
def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999):
with torch.no_grad():
res_params = dict(result.named_parameters())
new_params = dict(new_iterate_model.named_parameters())
for k in res_params.keys():
res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay)
def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'):
batch_size, _, height, width = base_tensor.shape
cur_height, cur_width = height, width
result = []
align_corners = False if scale_mode in ('bilinear', 'bicubic') else None
for _ in range(scales):
cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device)
cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners)
result.append(cur_sample_scaled)
cur_height //= 2
cur_width //= 2
return torch.cat(result, dim=1)
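
# Example (illustrative): with base_tensor of shape (B, C, 256, 256) and
# scales=6, noise is drawn at 256, 128, 64, 32, 16 and 8 px, each sample is
# upsampled back to 256x256, and all are concatenated along channels,
# yielding (B, 6, 256, 256).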
class BaseInpaintingTrainingModule(ptl.LightningModule):
def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
average_generator_period=10, store_discr_outputs_for_vis=False,
**kwargs):
super().__init__(*args, **kwargs)
LOGGER.info('BaseInpaintingTrainingModule init called')
self.config = config
self.generator = make_generator(config, **self.config.generator)
self.use_ddp = use_ddp
if not get_has_ddp_rank():
LOGGER.info(f'Generator\n{self.generator}')
# if not predict_only:
# self.save_hyperparameters(self.config)
# self.discriminator = make_discriminator(**self.config.discriminator)
# self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial)
# self.visualizer = make_visualizer(**self.config.visualizer)
# self.val_evaluator = make_evaluator(**self.config.evaluator)
# self.test_evaluator = make_evaluator(**self.config.evaluator)
#
# if not get_has_ddp_rank():
# LOGGER.info(f'Discriminator\n{self.discriminator}')
#
# extra_val = self.config.data.get('extra_val', ())
# if extra_val:
# self.extra_val_titles = list(extra_val)
# self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator)
# for k in extra_val})
# else:
# self.extra_evaluators = {}
#
# self.average_generator = average_generator
# self.generator_avg_beta = generator_avg_beta
# self.average_generator_start_step = average_generator_start_step
# self.average_generator_period = average_generator_period
# self.generator_average = None
# self.last_generator_averaging_step = -1
# self.store_discr_outputs_for_vis = store_discr_outputs_for_vis
#
# if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0:
# self.loss_l1 = nn.L1Loss(reduction='none')
#
# if self.config.losses.get("mse", {"weight": 0})['weight'] > 0:
# self.loss_mse = nn.MSELoss(reduction='none')
#
# if self.config.losses.perceptual.weight > 0:
# self.loss_pl = PerceptualLoss()
#
# # if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
# # self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
# # else:
# # self.loss_resnet_pl = None
#
# self.loss_resnet_pl = None
self.visualize_each_iters = visualize_each_iters
LOGGER.info('BaseInpaintingTrainingModule init done')
def configure_optimizers(self):
discriminator_params = list(self.discriminator.parameters())
return [
dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)),
dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)),
]
def train_dataloader(self):
kwargs = dict(self.config.data.train)
if self.use_ddp:
kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
rank=self.trainer.global_rank,
shuffle=True)
        dataloader = make_default_train_dataloader(**kwargs)
return dataloader
def val_dataloader(self):
res = [make_default_val_dataloader(**self.config.data.val)]
if self.config.data.visual_test is not None:
res = res + [make_default_val_dataloader(**self.config.data.visual_test)]
else:
res = res + res
extra_val = self.config.data.get('extra_val', ())
if extra_val:
res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles]
return res
def training_step(self, batch, batch_idx, optimizer_idx=None):
self._is_training_step = True
return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx)
def validation_step(self, batch, batch_idx, dataloader_idx):
extra_val_key = None
if dataloader_idx == 0:
mode = 'val'
elif dataloader_idx == 1:
mode = 'test'
else:
mode = 'extra_val'
extra_val_key = self.extra_val_titles[dataloader_idx - 2]
self._is_training_step = False
return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key)
def training_step_end(self, batch_parts_outputs):
if self.training and self.average_generator \
and self.global_step >= self.average_generator_start_step \
and self.global_step >= self.last_generator_averaging_step + self.average_generator_period:
if self.generator_average is None:
self.generator_average = copy.deepcopy(self.generator)
else:
update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta)
self.last_generator_averaging_step = self.global_step
full_loss = (batch_parts_outputs['loss'].mean()
if torch.is_tensor(batch_parts_outputs['loss']) # loss is not tensor when no discriminator used
else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
self.log_dict(log_info, on_step=True, on_epoch=False)
return full_loss
def validation_epoch_end(self, outputs):
outputs = [step_out for out_group in outputs for step_out in out_group]
averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs)
self.log_dict({k: v.mean() for k, v in averaged_logs.items()})
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
# standard validation
val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s]
val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states)
val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0)
val_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, '
f'total {self.global_step} iterations:\n{val_evaluator_res_df}')
for k, v in flatten_dict(val_evaluator_res).items():
self.log(f'val_{k}', v)
# standard visual test
test_evaluator_states = [s['test_evaluator_state'] for s in outputs
if 'test_evaluator_state' in s]
test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states)
test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0)
test_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, '
f'total {self.global_step} iterations:\n{test_evaluator_res_df}')
for k, v in flatten_dict(test_evaluator_res).items():
self.log(f'test_{k}', v)
# extra validations
if self.extra_evaluators:
for cur_eval_title, cur_evaluator in self.extra_evaluators.items():
cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state'
cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s]
cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states)
cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0)
cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, '
f'total {self.global_step} iterations:\n{cur_evaluator_res_df}')
for k, v in flatten_dict(cur_evaluator_res).items():
self.log(f'extra_val_{cur_eval_title}_{k}', v)
def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None):
if optimizer_idx == 0: # step for generator
set_requires_grad(self.generator, True)
set_requires_grad(self.discriminator, False)
elif optimizer_idx == 1: # step for discriminator
set_requires_grad(self.generator, False)
set_requires_grad(self.discriminator, True)
batch = self(batch)
total_loss = 0
metrics = {}
if optimizer_idx is None or optimizer_idx == 0: # step for generator
total_loss, metrics = self.generator_loss(batch)
        elif optimizer_idx == 1:  # step for discriminator (the None case is handled by the branch above)
if self.config.losses.adversarial.weight > 0:
total_loss, metrics = self.discriminator_loss(batch)
if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'):
if self.config.losses.adversarial.weight > 0:
if self.store_discr_outputs_for_vis:
with torch.no_grad():
self.store_discr_outputs(batch)
vis_suffix = f'_{mode}'
if mode == 'extra_val':
vis_suffix += f'_{extra_val_key}'
self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix)
metrics_prefix = f'{mode}_'
if mode == 'extra_val':
metrics_prefix += f'{extra_val_key}_'
result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix))
if mode == 'val':
result['val_evaluator_state'] = self.val_evaluator.process_batch(batch)
elif mode == 'test':
result['test_evaluator_state'] = self.test_evaluator.process_batch(batch)
elif mode == 'extra_val':
result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch)
return result
def get_current_generator(self, no_average=False):
if not no_average and not self.training and self.average_generator and self.generator_average is not None:
return self.generator_average
return self.generator
def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys"""
raise NotImplementedError()
def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
raise NotImplementedError()
def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
raise NotImplementedError()
def store_discr_outputs(self, batch):
out_size = batch['image'].shape[2:]
discr_real_out, _ = self.discriminator(batch['image'])
discr_fake_out, _ = self.discriminator(batch['predicted_image'])
batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest')
batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest')
batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake']
def get_ddp_rank(self):
return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None
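
if __name__ == '__main__':
    # Smoke test (illustrative): build an AdamW optimizer over a toy module.
    toy = nn.Linear(4, 4)
    opt = make_optimizer(toy.parameters(), kind='adamw', lr=1e-3)
    print(type(opt).__name__)  # AdamW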

View File

@@ -0,0 +1,175 @@
import logging
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
# from annotator.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params
from annotator.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter
from annotator.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss
# from annotator.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator
from annotator.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise
from annotator.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp
LOGGER = logging.getLogger(__name__)
def make_constant_area_crop_batch(batch, **kwargs):
crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
img_width=batch['image'].shape[3],
**kwargs)
batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width]
batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width]
return batch
class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
**kwargs):
super().__init__(*args, **kwargs)
self.concat_mask = concat_mask
self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
self.image_to_discriminator = image_to_discriminator
self.add_noise_kwargs = add_noise_kwargs
self.noise_fill_hole = noise_fill_hole
self.const_area_crop_kwargs = const_area_crop_kwargs
self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
if distance_weighter_kwargs is not None else None
self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
self.fake_fakes_proba = fake_fakes_proba
if self.fake_fakes_proba > 1e-3:
self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
def forward(self, batch):
if self.training and self.rescale_size_getter is not None:
cur_size = self.rescale_size_getter(self.global_step)
batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
if self.training and self.const_area_crop_kwargs is not None:
batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
img = batch['image']
mask = batch['mask']
masked_img = img * (1 - mask)
if self.add_noise_kwargs is not None:
noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs)
if self.noise_fill_hole:
masked_img = masked_img + mask * noise[:, :masked_img.shape[1]]
masked_img = torch.cat([masked_img, noise], dim=1)
if self.concat_mask:
masked_img = torch.cat([masked_img, mask], dim=1)
batch['predicted_image'] = self.generator(masked_img)
batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
if self.fake_fakes_proba > 1e-3:
if self.training and torch.rand(1).item() < self.fake_fakes_proba:
batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
batch['use_fake_fakes'] = True
else:
batch['fake_fakes'] = torch.zeros_like(img)
batch['fake_fakes_masks'] = torch.zeros_like(mask)
batch['use_fake_fakes'] = False
batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
if self.refine_mask_for_losses is not None and self.training \
else mask
return batch
def generator_loss(self, batch):
img = batch['image']
predicted_img = batch[self.image_to_discriminator]
original_mask = batch['mask']
supervised_mask = batch['mask_for_losses']
# L1
l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
self.config.losses.l1.weight_known,
self.config.losses.l1.weight_missing)
total_loss = l1_value
metrics = dict(gen_l1=l1_value)
# vgg-based perceptual loss
if self.config.losses.perceptual.weight > 0:
pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight
total_loss = total_loss + pl_value
metrics['gen_pl'] = pl_value
# discriminator
# adversarial_loss calls backward by itself
mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img,
generator=self.generator, discriminator=self.discriminator)
discr_real_pred, discr_real_features = self.discriminator(img)
discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img,
fake_batch=predicted_img,
discr_real_pred=discr_real_pred,
discr_fake_pred=discr_fake_pred,
mask=mask_for_discr)
total_loss = total_loss + adv_gen_loss
metrics['gen_adv'] = adv_gen_loss
metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
# feature matching
if self.config.losses.feature_matching.weight > 0:
need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False)
mask_for_fm = supervised_mask if need_mask_in_fm else None
fm_value = feature_matching_loss(discr_fake_features, discr_real_features,
mask=mask_for_fm) * self.config.losses.feature_matching.weight
total_loss = total_loss + fm_value
metrics['gen_fm'] = fm_value
if self.loss_resnet_pl is not None:
resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
total_loss = total_loss + resnet_pl_value
metrics['gen_resnet_pl'] = resnet_pl_value
return total_loss, metrics
def discriminator_loss(self, batch):
total_loss = 0
metrics = {}
predicted_img = batch[self.image_to_discriminator].detach()
self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img,
generator=self.generator, discriminator=self.discriminator)
discr_real_pred, discr_real_features = self.discriminator(batch['image'])
discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'],
fake_batch=predicted_img,
discr_real_pred=discr_real_pred,
discr_fake_pred=discr_fake_pred,
mask=batch['mask'])
total_loss = total_loss + adv_discr_loss
metrics['discr_adv'] = adv_discr_loss
metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
if batch.get('use_fake_fakes', False):
fake_fakes = batch['fake_fakes']
self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes,
generator=self.generator, discriminator=self.discriminator)
discr_fake_fakes_pred, _ = self.discriminator(fake_fakes)
fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss(
real_batch=batch['image'],
fake_batch=fake_fakes,
discr_real_pred=discr_real_pred,
discr_fake_pred=discr_fake_fakes_pred,
mask=batch['mask']
)
total_loss = total_loss + fake_fakes_adv_discr_loss
metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
return total_loss, metrics
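
# Batch contract sketch (illustrative): forward() expects
#   batch['image']: (B, 3, H, W) float in [0, 1]
#   batch['mask']:  (B, 1, H, W), where 1 marks the hole to inpaint
# and adds 'predicted_image', 'inpainted' and 'mask_for_losses' (plus the
# fake-fakes keys when that branch is enabled). With concat_mask=True the
# generator therefore receives image channels + 1 input channels.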

View File

@@ -0,0 +1,15 @@
import logging
from annotator.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer
from annotator.lama.saicinpainting.training.visualizers.noop import NoopVisualizer
def make_visualizer(kind, **kwargs):
logging.info(f'Make visualizer {kind}')
if kind == 'directory':
return DirectoryVisualizer(**kwargs)
if kind == 'noop':
return NoopVisualizer()
raise ValueError(f'Unknown visualizer kind {kind}')

View File

@@ -0,0 +1,73 @@
import abc
from typing import Dict, List
import numpy as np
import torch
from skimage import color
from skimage.segmentation import mark_boundaries
from . import colors
COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
class BaseVisualizer:
@abc.abstractmethod
def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
"""
        Take a batch, build a visualization image from it and save or show it
"""
raise NotImplementedError()
def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
last_without_mask=True, rescale_keys=None, mask_only_first=None,
black_mask=False) -> np.ndarray:
mask = images_dict['mask'] > 0.5
result = []
for i, k in enumerate(keys):
img = images_dict[k]
img = np.transpose(img, (1, 2, 0))
if rescale_keys is not None and k in rescale_keys:
img = img - img.min()
img /= img.max() + 1e-5
if len(img.shape) == 2:
img = np.expand_dims(img, 2)
if img.shape[2] == 1:
img = np.repeat(img, 3, axis=2)
elif (img.shape[2] > 3):
img_classes = img.argmax(2)
img = color.label2rgb(img_classes, colors=COLORS)
if mask_only_first:
need_mark_boundaries = i == 0
else:
need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
if need_mark_boundaries:
if black_mask:
img = img * (1 - mask[0][..., None])
img = mark_boundaries(img,
mask[0],
color=(1., 0., 0.),
outline_color=(1., 1., 1.),
mode='thick')
result.append(img)
return np.concatenate(result, axis=1)
def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
last_without_mask=True, rescale_keys=None) -> np.ndarray:
batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
if k in keys or k == 'mask'}
batch_size = next(iter(batch.values())).shape[0]
items_to_vis = min(batch_size, max_items)
result = []
for i in range(items_to_vis):
cur_dct = {k: tens[i] for k, tens in batch.items()}
result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
rescale_keys=rescale_keys))
return np.concatenate(result, axis=0)
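
if __name__ == '__main__':
    # Smoke test (illustrative): two fake samples, two panels each.
    fake_batch = {
        'image': torch.rand(2, 3, 64, 64),
        'predicted_image': torch.rand(2, 3, 64, 64),
        'mask': (torch.rand(2, 1, 64, 64) > 0.5).float(),
    }
    grid = visualize_mask_and_images_batch(fake_batch, ['image', 'predicted_image'])
    print(grid.shape)  # (128, 128, 3): 2 rows x 2 panels of 64x64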

View File

@@ -0,0 +1,76 @@
import random
import colorsys
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
# https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
"""
Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
:param nlabels: Number of labels (size of colormap)
:param type: 'bright' for strong colors, 'soft' for pastel colors
:param first_color_black: Option to use first color as black, True or False
:param last_color_black: Option to use last color as black, True or False
:param verbose: Prints the number of labels and shows the colormap. True or False
:return: colormap for matplotlib
"""
    if type not in ('bright', 'soft'):
        raise ValueError('type must be "bright" or "soft", got {!r}'.format(type))
if verbose:
print('Number of labels: ' + str(nlabels))
# Generate color map for bright colors, based on hsv
if type == 'bright':
randHSVcolors = [(np.random.uniform(low=0.0, high=1),
np.random.uniform(low=0.2, high=1),
np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
# Convert HSV list to RGB
randRGBcolors = []
for HSVcolor in randHSVcolors:
randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
if first_color_black:
randRGBcolors[0] = [0, 0, 0]
if last_color_black:
randRGBcolors[-1] = [0, 0, 0]
random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
# Generate soft pastel colors, by limiting the RGB spectrum
if type == 'soft':
low = 0.6
high = 0.95
randRGBcolors = [(np.random.uniform(low=low, high=high),
np.random.uniform(low=low, high=high),
np.random.uniform(low=low, high=high)) for i in range(nlabels)]
if first_color_black:
randRGBcolors[0] = [0, 0, 0]
if last_color_black:
randRGBcolors[-1] = [0, 0, 0]
random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
# Display colorbar
if verbose:
        from matplotlib import colors, colorbar
        fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
        bounds = np.linspace(0, nlabels, nlabels + 1)
        norm = colors.BoundaryNorm(bounds, nlabels)
        colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
                              boundaries=bounds, format='%1i', orientation='horizontal')
return randRGBcolors, random_colormap
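
# Usage sketch (illustrative, not part of the original module): the first
# return value is the raw list of RGB triples, the second a matplotlib
# colormap built from it.
if __name__ == '__main__':
    rgb_list, cmap = generate_colors(10, type='bright', last_color_black=True)
    print(len(rgb_list), cmap.N)  # 10 10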


@@ -0,0 +1,36 @@
import os
import cv2
import numpy as np
from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
from annotator.lama.saicinpainting.utils import check_and_warn_input_range
class DirectoryVisualizer(BaseVisualizer):
DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
last_without_mask=True, rescale_keys=None):
self.outdir = outdir
os.makedirs(self.outdir, exist_ok=True)
self.key_order = key_order
self.max_items_in_batch = max_items_in_batch
self.last_without_mask = last_without_mask
self.rescale_keys = rescale_keys
def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
last_without_mask=self.last_without_mask,
rescale_keys=self.rescale_keys)
vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
os.makedirs(curoutdir, exist_ok=True)
rank_suffix = f'_r{rank}' if rank is not None else ''
out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(out_fname, vis_img)
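
# Minimal runnable sketch (illustrative; '/tmp/lama_vis' is a hypothetical
# path): writes one JPEG grid per call under outdir/epoch{epoch:04d}/.
if __name__ == '__main__':
    import torch
    vis = DirectoryVisualizer(outdir='/tmp/lama_vis', key_order=['image', 'predicted_image'])
    demo_batch = {
        'image': torch.rand(1, 3, 64, 64),
        'predicted_image': torch.rand(1, 3, 64, 64),
        'mask': (torch.rand(1, 1, 64, 64) > 0.5).float(),
    }
    vis(0, 0, demo_batch)  # -> /tmp/lama_vis/epoch0000/batch0000000.jpg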


@@ -0,0 +1,9 @@
from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer
class NoopVisualizer(BaseVisualizer):
def __init__(self, *args, **kwargs):
pass
def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
pass


@@ -0,0 +1,174 @@
import bisect
import functools
import logging
import numbers
import os
import signal
import sys
import traceback
import warnings
import torch
from pytorch_lightning import seed_everything
LOGGER = logging.getLogger(__name__)
def check_and_warn_input_range(tensor, min_value, max_value, name):
actual_min = tensor.min()
actual_max = tensor.max()
if actual_min < min_value or actual_max > max_value:
        warnings.warn(f"{name} must be in the {min_value}..{max_value} range, but its values span {actual_min}..{actual_max}")
def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
for k, v in cur_dict.items():
target_key = prefix + k
target[target_key] = target.get(target_key, default) + v
def average_dicts(dict_list):
result = {}
norm = 1e-3
for dct in dict_list:
sum_dict_with_prefix(result, dct, '')
norm += 1
for k in list(result):
result[k] /= norm
return result
def add_prefix_to_keys(dct, prefix):
return {prefix + k: v for k, v in dct.items()}
def set_requires_grad(module, value):
for param in module.parameters():
param.requires_grad = value
def flatten_dict(dct):
result = {}
for k, v in dct.items():
if isinstance(k, tuple):
k = '_'.join(k)
if isinstance(v, dict):
for sub_k, sub_v in flatten_dict(v).items():
result[f'{k}_{sub_k}'] = sub_v
else:
result[k] = v
return result
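
# Worked example (illustrative): nested keys are joined with underscores, and
# tuple keys are joined the same way:
#   flatten_dict({'a': {'b': 1}, ('c', 'd'): 2}) -> {'a_b': 1, 'c_d': 2}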
class LinearRamp:
def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
self.start_value = start_value
self.end_value = end_value
self.start_iter = start_iter
self.end_iter = end_iter
def __call__(self, i):
if i < self.start_iter:
return self.start_value
if i >= self.end_iter:
return self.end_value
part = (i - self.start_iter) / (self.end_iter - self.start_iter)
return self.start_value * (1 - part) + self.end_value * part
class LadderRamp:
def __init__(self, start_iters, values):
self.start_iters = start_iters
self.values = values
assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))
def __call__(self, i):
segment_i = bisect.bisect_right(self.start_iters, i)
return self.values[segment_i]
def get_ramp(kind='ladder', **kwargs):
if kind == 'linear':
return LinearRamp(**kwargs)
if kind == 'ladder':
return LadderRamp(**kwargs)
raise ValueError(f'Unexpected ramp kind: {kind}')
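
# Worked example (illustrative): LinearRamp interpolates over an iteration
# window, LadderRamp is a step function over iteration thresholds:
#   LinearRamp(0.0, 1.0, start_iter=100, end_iter=200) maps
#     iter 0 -> 0.0, 100 -> 0.0, 150 -> 0.5, 200+ -> 1.0
#   LadderRamp(start_iters=[1000, 5000], values=[0.1, 0.01, 0.001]) maps
#     iter 0..999 -> 0.1, 1000..4999 -> 0.01, 5000+ -> 0.001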
def print_traceback_handler(sig, frame):
LOGGER.warning(f'Received signal {sig}')
bt = ''.join(traceback.format_stack())
LOGGER.warning(f'Requested stack trace:\n{bt}')
def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
LOGGER.warning(f'Setting signal {sig} handler {handler}')
signal.signal(sig, handler)
def handle_deterministic_config(config):
seed = dict(config).get('seed', None)
if seed is None:
return False
seed_everything(seed)
return True
def get_shape(t):
if torch.is_tensor(t):
return tuple(t.shape)
elif isinstance(t, dict):
return {n: get_shape(q) for n, q in t.items()}
elif isinstance(t, (list, tuple)):
return [get_shape(q) for q in t]
elif isinstance(t, numbers.Number):
return type(t)
else:
raise ValueError('unexpected type {}'.format(type(t)))
def get_has_ddp_rank():
master_port = os.environ.get('MASTER_PORT', None)
node_rank = os.environ.get('NODE_RANK', None)
local_rank = os.environ.get('LOCAL_RANK', None)
world_size = os.environ.get('WORLD_SIZE', None)
has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
return has_rank
def handle_ddp_subprocess():
def main_decorator(main_func):
@functools.wraps(main_func)
def new_main(*args, **kwargs):
# Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
has_parent = parent_cwd is not None
has_rank = get_has_ddp_rank()
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
if has_parent:
# we are in the worker
sys.argv.extend([
f'hydra.run.dir={parent_cwd}',
# 'hydra/hydra_logging=disabled',
# 'hydra/job_logging=disabled'
])
# do nothing if this is a top-level process
# TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
main_func(*args, **kwargs)
return new_main
return main_decorator
def handle_ddp_parent_process():
parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
has_parent = parent_cwd is not None
has_rank = get_has_ddp_rank()
assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
if parent_cwd is None:
os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
return has_parent
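
# Minimal runnable sketch (illustrative, not part of the original module):
# get_shape mirrors a nested structure, replacing tensors with their shapes
# and plain numbers with their types.
if __name__ == '__main__':
    demo = {'image': torch.zeros(2, 3, 256, 256), 'masks': [torch.zeros(2, 1, 256, 256)], 'iter': 5}
    print(get_shape(demo))
    # {'image': (2, 3, 256, 256), 'masks': [(2, 1, 256, 256)], 'iter': <class 'int'>}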


@@ -0,0 +1,157 @@
run_title: b18_ffc075_batch8x15
training_model:
kind: default
visualize_each_iters: 1000
concat_mask: true
store_discr_outputs_for_vis: true
losses:
l1:
weight_missing: 0
weight_known: 10
perceptual:
weight: 0
adversarial:
kind: r1
weight: 10
gp_coef: 0.001
mask_as_fake_target: true
allow_scale_mask: true
feature_matching:
weight: 100
resnet_pl:
weight: 30
weights_path: ${env:TORCH_HOME}
optimizers:
generator:
kind: adam
lr: 0.001
discriminator:
kind: adam
lr: 0.0001
visualizer:
key_order:
- image
- predicted_image
- discr_output_fake
- discr_output_real
- inpainted
rescale_keys:
- discr_output_fake
- discr_output_real
kind: directory
outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
location:
data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
data:
batch_size: 15
val_batch_size: 2
num_workers: 3
train:
indir: ${location.data_root_dir}/train
out_size: 256
mask_gen_kwargs:
irregular_proba: 1
irregular_kwargs:
max_angle: 4
max_len: 200
max_width: 100
max_times: 5
min_times: 1
box_proba: 1
box_kwargs:
margin: 10
bbox_min_size: 30
bbox_max_size: 150
max_times: 3
min_times: 1
segm_proba: 0
segm_kwargs:
confidence_threshold: 0.5
max_object_area: 0.5
min_mask_area: 0.07
downsample_levels: 6
num_variants_per_mask: 1
rigidness_mode: 1
max_foreground_coverage: 0.3
max_foreground_intersection: 0.7
max_mask_intersection: 0.1
max_hidden_area: 0.1
max_scale_change: 0.25
horizontal_flip: true
max_vertical_shift: 0.2
position_shuffle: true
transform_variant: distortions
dataloader_kwargs:
batch_size: ${data.batch_size}
shuffle: true
num_workers: ${data.num_workers}
val:
indir: ${location.data_root_dir}/val
img_suffix: .png
dataloader_kwargs:
batch_size: ${data.val_batch_size}
shuffle: false
num_workers: ${data.num_workers}
visual_test:
indir: ${location.data_root_dir}/korean_test
img_suffix: _input.png
pad_out_to_modulo: 32
dataloader_kwargs:
batch_size: 1
shuffle: false
num_workers: ${data.num_workers}
generator:
kind: ffc_resnet
input_nc: 4
output_nc: 3
ngf: 64
n_downsampling: 3
n_blocks: 18
add_out_act: sigmoid
init_conv_kwargs:
ratio_gin: 0
ratio_gout: 0
enable_lfu: false
downsample_conv_kwargs:
ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
enable_lfu: false
resnet_conv_kwargs:
ratio_gin: 0.75
ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
enable_lfu: false
discriminator:
kind: pix2pixhd_nlayer
input_nc: 3
ndf: 64
n_layers: 4
evaluator:
kind: default
inpainted_key: inpainted
integral_kind: ssim_fid100_f1
trainer:
kwargs:
gpus: -1
accelerator: ddp
max_epochs: 200
gradient_clip_val: 1
    log_gpu_memory: null  # YAML null; the unquoted `None` parses as the string "None"
limit_train_batches: 25000
val_check_interval: ${trainer.kwargs.limit_train_batches}
log_every_n_steps: 1000
precision: 32
terminate_on_nan: false
check_val_every_n_epoch: 1
num_sanity_val_steps: 8
limit_val_batches: 1000
replace_sampler_ddp: false
checkpoint_kwargs:
verbose: true
save_top_k: 5
save_last: true
period: 1
monitor: val_ssim_fid100_f1_total_mean
mode: max
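
# Note (added for clarity, not part of the original config): ${...} values are
# OmegaConf interpolations resolved at load time, and ${env:VAR} reads from the
# environment (TORCH_HOME, USER). The inference code below loads this file with
# yaml.safe_load + OmegaConf.create, forces training_model.predict_only=true
# and visualizer.kind=noop; the cluster-specific location and trainer sections
# are not needed for inference.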


@@ -0,0 +1,167 @@
import os
import cv2
import torch
import numpy as np
import yaml
import einops
from omegaconf import OmegaConf
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.utils import numpy_to_pytorch, resize_image_with_pad
from modules_forge.shared import preprocessor_dir, add_supported_preprocessor
from modules.modelloader import load_file_from_url
from annotator.lama.saicinpainting.training.trainers import load_checkpoint
class PreprocessorInpaint(Preprocessor):
def __init__(self):
super().__init__()
self.name = 'inpaint_global_harmonious'
self.tags = ['Inpaint']
self.model_filename_filters = ['inpaint']
self.slider_resolution = PreprocessorParameter(visible=False)
self.fill_mask_with_one_when_resize_and_fill = True
self.expand_mask_when_resize_and_fill = True
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        mask = mask.round()
        # ControlNet inpaint convention: keep unmasked pixels, mark masked pixels as -1
        mixed_cond = cond * (1.0 - mask) - mask
        return mixed_cond, None
class PreprocessorInpaintOnly(PreprocessorInpaint):
def __init__(self):
super().__init__()
self.name = 'inpaint_only'
self.image = None
self.mask = None
self.latent = None
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
mask = mask.round()
self.image = cond
self.mask = mask
vae = process.sd_model.forge_objects.vae
        # Forge VAE wrapper: memory management, bf16 casting, and tiled fallback are handled internally.
latent_image = vae.encode(self.image.movedim(1, -1))
latent_image = process.sd_model.forge_objects.vae.first_stage_model.process_in(latent_image)
B, C, H, W = latent_image.shape
        latent_mask = self.mask
        # Resample the mask to pixel resolution, then max-pool back to latent resolution
        # so any latent cell overlapping the mask counts as fully masked.
        latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
        latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent_image)
unet = process.sd_model.forge_objects.unet.clone()
        def pre_cfg(model, c, uc, x, timestep, model_options):
            # Re-noise the clean latent to the current noise level, then keep the
            # sampler's latent only inside the mask (1 = region being inpainted).
            noisy_latent = latent_image.to(x) + timestep[:, None, None, None].to(x) * torch.randn_like(latent_image).to(x)
            x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x))
            return model, c, uc, x, timestep, model_options
        def post_cfg(args):
            # After each denoising step, paste the original latent back outside the mask.
            denoised = args['denoised']
            denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
            return denoised
unet.add_sampler_pre_cfg_function(pre_cfg)
unet.set_model_sampler_post_cfg_function(post_cfg)
process.sd_model.forge_objects.unet = unet
self.latent = latent_image
mixed_cond = cond * (1.0 - mask) - mask
return mixed_cond, None
def process_after_every_sampling(self, process, params, *args, **kwargs):
a1111_batch_result = args[0]
new_results = []
for img in a1111_batch_result.images:
            # Feather the mask (dilate + blur) so the sampled region blends smoothly
            sigma = 7
            mask = self.mask[0, 0].detach().cpu().numpy().astype(np.float32)
            mask = cv2.dilate(mask, np.ones((sigma, sigma), dtype=np.uint8))
            mask = cv2.blur(mask, (sigma, sigma))[None]
mask = torch.from_numpy(np.ascontiguousarray(mask).copy()).to(img).clip(0, 1)
raw = self.image[0].to(img).clip(0, 1)
img = img.clip(0, 1)
new_results.append(raw * (1.0 - mask) + img * mask)
a1111_batch_result.images = new_results
return
class PreprocessorInpaintLama(PreprocessorInpaintOnly):
def __init__(self):
super().__init__()
self.name = 'inpaint_only+lama'
def load_model(self):
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetLama.pth"
model_path = load_file_from_url(remote_model_path, model_dir=preprocessor_dir)
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama_config.yaml')
        with open(config_path, 'rt') as f:
            cfg = yaml.safe_load(f)
cfg = OmegaConf.create(cfg)
cfg.training_model.predict_only = True
cfg.visualizer.kind = 'noop'
model = load_checkpoint(cfg, os.path.abspath(model_path), strict=False, map_location='cpu')
self.setup_model_patcher(model)
return
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
if input_mask is None:
return input_image
H, W, C = input_image.shape
raw_color = input_image.copy()
raw_mask = input_mask.copy()
input_image, remove_pad = resize_image_with_pad(input_image, 256)
input_mask, remove_pad = resize_image_with_pad(input_mask, 256)
input_mask = input_mask[..., :1]
self.load_model()
self.move_all_model_patchers_to_gpu()
color = np.ascontiguousarray(input_image).astype(np.float32) / 255.0
mask = np.ascontiguousarray(input_mask).astype(np.float32) / 255.0
with torch.no_grad():
color = self.send_tensor_to_model_device(torch.from_numpy(color))
mask = self.send_tensor_to_model_device(torch.from_numpy(mask))
mask = (mask > 0.5).float()
color = color * (1 - mask)
image_feed = torch.cat([color, mask], dim=2)
image_feed = einops.rearrange(image_feed, 'h w c -> 1 c h w')
prd_color = self.model_patcher.model(image_feed)[0]
prd_color = einops.rearrange(prd_color, 'c h w -> h w c')
prd_color = prd_color * mask + color * (1 - mask)
prd_color *= 255.0
prd_color = prd_color.detach().cpu().numpy().clip(0, 255).astype(np.uint8)
prd_color = remove_pad(prd_color)
prd_color = cv2.resize(prd_color, (W, H))
alpha = raw_mask.astype(np.float32) / 255.0
fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
fin_color = fin_color.clip(0, 255).astype(np.uint8)
return fin_color
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
        sigma_max = process.sd_model.forge_objects.unet.model.predictor.sigma_max
        original_noise = kwargs['noise']
        # Bias the initial noise towards the encoded (LaMa-inpainted) latent so denoising starts from it
        process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
        return cond, mask
add_supported_preprocessor(PreprocessorInpaint())
add_supported_preprocessor(PreprocessorInpaintOnly())
add_supported_preprocessor(PreprocessorInpaintLama())
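
# Standalone sketch (illustrative, assuming only numpy/cv2/torch) of the final
# compositing step in process_after_every_sampling: the mask is dilated and
# blurred so sampled content is feathered into the untouched original.
if __name__ == '__main__':
    raw = torch.rand(3, 64, 64)
    sampled = torch.rand(3, 64, 64)
    hard_mask = (np.random.rand(64, 64) > 0.5).astype(np.float32)
    sigma = 7
    soft_mask = cv2.dilate(hard_mask, np.ones((sigma, sigma), dtype=np.uint8))
    soft_mask = cv2.blur(soft_mask, (sigma, sigma))[None]
    soft_mask = torch.from_numpy(soft_mask).clip(0, 1)
    blended = raw * (1.0 - soft_mask) + sampled * soft_mask
    print(blended.shape)  # torch.Size([3, 64, 64])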