mirror of
https://github.com/huchenlei/Depth-Anything.git
synced 2026-01-26 15:29:46 +00:00
88 lines
2.9 KiB
Python
88 lines
2.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import random
|
|
import math
|
|
import numpy as np
|
|
|
|
|
|
class MaskingGenerator:
    """Generate random block-wise boolean masks over a 2-D grid of patches.

    A mask is built by repeatedly stamping randomly sized and positioned
    rectangles onto a ``height x width`` grid until the requested number of
    cells is masked, or until no further rectangle can be placed.

    Args:
        input_size: Grid size. Either a ``(height, width)`` tuple or a single
            value used for both dimensions.
        num_masking_patches: Stored on the instance and used as the default
            for ``max_num_patches``. It is not read by ``__call__``, which
            takes its own target count.
        min_num_patches: Lower bound of the area sampled for each rectangle.
        max_num_patches: Upper bound on the number of cells a single
            rectangle may add. Defaults to ``num_masking_patches``; when both
            are ``None`` the only cap is the number of cells still missing.
        min_aspect: Smallest rectangle aspect ratio (height / width).
        max_aspect: Largest rectangle aspect ratio. Defaults to
            ``1 / min_aspect`` when falsy.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=4,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        # May legitimately stay None when neither limit was supplied;
        # __call__ and __repr__ both handle that case.
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space so that a ratio r
        # and its inverse 1/r are equally likely.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    @staticmethod
    def _format_count(value):
        """Format an optional patch count as ``%d``, or ``"None"`` if unset."""
        return "None" if value is None else "%d" % value

    def __repr__(self):
        # The two optional counts are formatted separately because "%d"
        # raises TypeError on None, which is their default value.
        repr_str = "Generator(%d, %d -> [%d ~ %s], max = %s, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self._format_count(self.max_num_patches),
            self._format_count(self.num_masking_patches),
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        """Return the grid size as a ``(height, width)`` tuple."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try to stamp one random rectangle onto ``mask`` in place.

        Up to 10 rectangles are sampled. The first one that fits strictly
        inside the grid and would add between 1 and ``max_mask_patches`` new
        cells is applied.

        Args:
            mask: 2-D array indexed ``[row, col]``; modified in place.
            max_mask_patches: Maximum number of cells this call may add.

        Returns:
            The number of cells newly masked (0 if every attempt failed).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if not (w < self.width and h < self.height):
                continue

            top = random.randint(0, self.height - h)
            left = random.randint(0, self.width - w)

            region = mask[top : top + h, left : left + w]
            num_masked = region.sum()
            # Accept only rectangles that add at least one new cell without
            # exceeding the budget; overlap with masked cells is allowed.
            if 0 < h * w - num_masked <= max_mask_patches:
                unmasked = region == 0
                delta += int(unmasked.sum())
                # `region` is a view, so this writes through to `mask`.
                region[unmasked] = 1

            if delta > 0:
                break
        return delta

    def __call__(self, num_masking_patches=0):
        """Build a fresh mask with at most ``num_masking_patches`` cells set.

        Args:
            num_masking_patches: Target number of masked cells.

        Returns:
            A boolean array of shape ``(height, width)``. Fewer cells than
            requested may be set if rectangle placement fails.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
        mask_count = 0
        while mask_count < num_masking_patches:
            max_mask_patches = num_masking_patches - mask_count
            # With no per-rectangle limit configured, the remaining budget is
            # the only cap (min() against None would raise TypeError).
            if self.max_num_patches is not None:
                max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
|