initial commit

This commit is contained in:
2025-10-21 13:37:07 +07:00
commit 9cd16e276a
1574 changed files with 2675557 additions and 0 deletions

View File

@@ -0,0 +1,192 @@
import spaces
import functools
import os
import shutil
import sys
import git
import gradio as gr
import numpy as np
import torch as torch
from PIL import Image
from gradio_imageslider import ImageSlider
import spaces
import argparse
import os
import logging
import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
import glob
import json
import cv2
import sys
from geo_models.geowizard_pipeline import DepthNormalEstimationPipeline
from geo_utils.seed_all import seed_all
import matplotlib.pyplot as plt
from geo_utils.de_normalized import align_scale_shift
from geo_utils.depth2normal import *
from diffusers import DiffusionPipeline, DDIMScheduler, AutoencoderKL
from geo_models.unet_2d_condition import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
device = spaces.gpu
with spaces.capture_gpu_object() as gpu_object:
vae = AutoencoderKL.from_pretrained(spaces.convert_root_path(), subfolder='vae')
scheduler = DDIMScheduler.from_pretrained(spaces.convert_root_path(), subfolder='scheduler')
image_encoder = CLIPVisionModelWithProjection.from_pretrained(spaces.convert_root_path(), subfolder="image_encoder")
feature_extractor = CLIPImageProcessor.from_pretrained(spaces.convert_root_path(), subfolder="feature_extractor")
unet = UNet2DConditionModel.from_pretrained(spaces.convert_root_path(), subfolder="unet")
pipe = DepthNormalEstimationPipeline(vae=vae,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
unet=unet,
scheduler=scheduler)
outputs_dir = "./outputs"
spaces.automatically_move_pipeline_components(pipe)
spaces.automatically_move_to_gpu_when_forward(pipe.vae.encoder, target_model=pipe.vae)
spaces.automatically_move_to_gpu_when_forward(pipe.vae.decoder, target_model=pipe.vae)
spaces.automatically_move_to_gpu_when_forward(pipe.vae.post_quant_conv, target_model=pipe.vae)
# spaces.change_attention_from_diffusers_to_forge(vae)
# spaces.change_attention_from_diffusers_to_forge(unet)
# pipe = pipe.to(device)
@spaces.GPU(gpu_objects=gpu_object, manual_load=True)
def depth_normal(img,
denoising_steps,
ensemble_size,
processing_res,
seed,
domain):
seed = int(seed)
if seed >= 0:
torch.manual_seed(seed)
pipe_out = pipe(
img,
denoising_steps=denoising_steps,
ensemble_size=ensemble_size,
processing_res=processing_res,
batch_size=0,
domain=domain,
show_progress_bar=True,
)
depth_colored = Image.fromarray(((1. - pipe_out.depth_np) * 255.0).clip(0, 255).astype(np.uint8))
normal_colored = pipe_out.normal_colored
return depth_colored, normal_colored
def run_demo():
custom_theme = gr.themes.Soft(primary_hue="blue").set(
button_secondary_background_fill="*neutral_100",
button_secondary_background_fill_hover="*neutral_200")
custom_css = '''#disp_image {
text-align: center; /* Horizontally center the content */
}'''
_TITLE = '''GeoWizard: Unleashing the Diffusion Priors for 3D Geometry Estimation from a Single Image'''
_DESCRIPTION = '''
<div>
Generate consistent depth and normal from single image. High quality and rich details. (PS: We find the demo running on ZeroGPU output slightly inferior results compared to A100 or 3060 with everything exactly the same.)
<a style="display:inline-block; margin-left: .5em" href='https://github.com/fuxiao0719/GeoWizard/'><img src='https://img.shields.io/github/stars/fuxiao0719/GeoWizard?style=social' /></a>
</div>
'''
_GPU_ID = 0
with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
with gr.Row():
with gr.Column(scale=1):
gr.Markdown('# ' + _TITLE)
gr.Markdown(_DESCRIPTION)
with gr.Row(variant='panel'):
with gr.Column(scale=1):
input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image')
example_folder = os.path.join(spaces.convert_root_path(), "files")
example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
gr.Examples(
examples=example_fns,
inputs=[input_image],
cache_examples=False,
label='Examples (click one of the images below to start)',
examples_per_page=30
)
with gr.Column(scale=1):
with gr.Accordion('Advanced options', open=True):
with gr.Column():
domain = gr.Radio(
[
("Outdoor", "outdoor"),
("Indoor", "indoor"),
("Object", "object"),
],
label="Data Type (Must Select One matches your image)",
value="indoor",
)
denoising_steps = gr.Slider(
label="Number of denoising steps (More steps, better quality)",
minimum=1,
maximum=50,
step=1,
value=10,
)
ensemble_size = gr.Slider(
label="Ensemble size (More steps, higher accuracy)",
minimum=1,
maximum=15,
step=1,
value=3,
)
seed = gr.Number(0, label='Random Seed. Negative values for not specifying')
processing_res = gr.Radio(
[
("Native", 0),
("Recommended", 768),
],
label="Processing resolution",
value=768,
)
run_btn = gr.Button('Generate', variant='primary', interactive=True)
with gr.Row():
with gr.Column():
depth = gr.Image(interactive=False, show_label=False)
with gr.Column():
normal = gr.Image(interactive=False, show_label=False)
run_btn.click(fn=depth_normal,
inputs=[input_image, denoising_steps,
ensemble_size,
processing_res,
seed,
domain],
outputs=[depth, normal]
)
return demo
demo = run_demo()
if __name__ == '__main__':
demo.queue().launch(share=True, max_threads=80)

View File

@@ -0,0 +1,684 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Some modifications are reimplemented in public environments by Xiao Fu and Mu Hu
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
# import xformers
from diffusers.utils import USE_PEFT_BACKEND
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
A gated self-attention dense layer that combines visual features and object features.
Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()
# we need a linear projection since we need cat visual feature and obj feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = CustomJointAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if self.use_ada_layer_norm:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_continuous:
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if self.use_ada_layer_norm_continuous:
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if self.use_ada_layer_norm_continuous:
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class CustomJointAttention(Attention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
from backend.attention import AttentionProcessorForge
self.set_processor(AttentionProcessorForge())
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
)
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
for module in self.net:
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
return hidden_states

View File

@@ -0,0 +1,370 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
from typing import Any, Dict, Union
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from diffusers import (
DiffusionPipeline,
DDIMScheduler,
AutoencoderKL,
)
from geo_models.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from geo_utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps
from geo_utils.colormap import kitti_colormap
from geo_utils.depth_ensemble import ensemble_depths
from geo_utils.normal_ensemble import ensemble_normals
from geo_utils.batch_size import find_batch_size
import cv2
class DepthNormalPipelineOutput(BaseOutput):
"""
Output class for Marigold monocular depth prediction pipeline.
Args:
depth_np (`np.ndarray`):
Predicted depth map, with depth values in the range of [0, 1].
depth_colored (`PIL.Image.Image`):
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
normal_np (`np.ndarray`):
Predicted normal map, with depth values in the range of [0, 1].
normal_colored (`PIL.Image.Image`):
Colorized normal map, with the shape of [3, H, W] and values in [0, 1].
uncertainty (`None` or `np.ndarray`):
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
"""
depth_np: np.ndarray
depth_colored: Image.Image
normal_np: np.ndarray
normal_colored: Image.Image
uncertainty: Union[None, np.ndarray]
class DepthNormalEstimationPipeline(DiffusionPipeline):
# two hyper-parameters
latent_scale_factor = 0.18215
def __init__(self,
unet:UNet2DConditionModel,
vae:AutoencoderKL,
scheduler:DDIMScheduler,
image_encoder:CLIPVisionModelWithProjection,
feature_extractor:CLIPImageProcessor,
):
super().__init__()
self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.img_embed = None
@torch.no_grad()
def __call__(self,
input_image:Image,
denoising_steps: int = 10,
ensemble_size: int = 10,
processing_res: int = 768,
match_input_res:bool =True,
batch_size:int = 0,
domain: str = "indoor",
color_map: str="Spectral",
show_progress_bar:bool = True,
ensemble_kwargs: Dict = None,
) -> DepthNormalPipelineOutput:
# inherit from thea Diffusion Pipeline
device = self.device
input_size = input_image.size
# adjust the input resolution.
if not match_input_res:
assert (
processing_res is not None
)," Value Error: `resize_output_back` is only valid with "
assert processing_res >=0
assert denoising_steps >=1
assert ensemble_size >=1
# --------------- Image Processing ------------------------
# Resize image
if processing_res >0:
input_image = resize_max_res(
input_image, max_edge_resolution=processing_res
)
# Convert the image to RGB, to 1. reomve the alpha channel.
input_image = input_image.convert("RGB")
image = np.array(input_image)
# Normalize RGB Values.
rgb = np.transpose(image,(2,0,1))
rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
rgb_norm = rgb_norm.to(device)
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
# ----------------- predicting depth -----------------
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
single_rgb_dataset = TensorDataset(duplicated_rgb)
# find the batch size
if batch_size>0:
_bs = batch_size
else:
_bs = 1
single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
# predicted the depth
depth_pred_ls = []
normal_pred_ls = []
if show_progress_bar:
iterable_bar = tqdm(
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
)
else:
iterable_bar = single_rgb_loader
for batch in iterable_bar:
(batched_image, )= batch # here the image is still around 0-1
depth_pred_raw, normal_pred_raw = self.single_infer(
input_rgb=batched_image,
num_inference_steps=denoising_steps,
domain=domain,
show_pbar=show_progress_bar,
)
depth_pred_ls.append(depth_pred_raw.detach().clone())
normal_pred_ls.append(normal_pred_raw.detach().clone())
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() #(10,224,768)
normal_preds = torch.concat(normal_pred_ls, axis=0).squeeze()
torch.cuda.empty_cache() # clear vram cache for ensembling
# ----------------- Test-time ensembling -----------------
if ensemble_size > 1:
depth_pred, pred_uncert = ensemble_depths(
depth_preds, **(ensemble_kwargs or {})
)
normal_pred = ensemble_normals(normal_preds)
else:
depth_pred = depth_preds
normal_pred = normal_preds
pred_uncert = None
# ----------------- Post processing -----------------
# Scale prediction to [0, 1]
min_d = torch.min(depth_pred)
max_d = torch.max(depth_pred)
depth_pred = (depth_pred - min_d) / (max_d - min_d)
# Convert to numpy
depth_pred = depth_pred.cpu().numpy().astype(np.float32)
normal_pred = normal_pred.cpu().numpy().astype(np.float32)
# Resize back to original resolution
if match_input_res:
pred_img = Image.fromarray(depth_pred)
pred_img = pred_img.resize(input_size)
depth_pred = np.asarray(pred_img)
normal_pred = cv2.resize(chw2hwc(normal_pred), input_size, interpolation = cv2.INTER_NEAREST)
# Clip output range: current size is the original size
depth_pred = depth_pred.clip(0, 1)
normal_pred = normal_pred.clip(-1, 1)
# Colorize
depth_colored = colorize_depth_maps(
depth_pred, 0, 1, cmap=color_map
).squeeze() # [3, H, W], value in (0, 1)
depth_colored = (depth_colored * 255).astype(np.uint8)
depth_colored_hwc = chw2hwc(depth_colored)
depth_colored_img = Image.fromarray(depth_colored_hwc)
normal_colored = ((normal_pred + 1)/2 * 255).astype(np.uint8)
normal_colored_img = Image.fromarray(normal_colored)
self.img_embed = None
return DepthNormalPipelineOutput(
depth_np = depth_pred,
depth_colored = depth_colored_img,
normal_np = normal_pred,
normal_colored = normal_colored_img,
uncertainty=pred_uncert,
)
def __encode_img_embed(self, rgb):
"""
Encode clip embeddings for img
"""
clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device=self.device, dtype=self.dtype)
clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device=self.device, dtype=self.dtype)
img_in_proc = TF.resize((rgb +1)/2,
(self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']),
interpolation=InterpolationMode.BICUBIC,
antialias=True
)
# do the normalization in float32 to preserve precision
img_in_proc = ((img_in_proc.float() - clip_image_mean) / clip_image_std).to(self.dtype)
img_embed = self.image_encoder(img_in_proc).image_embeds.unsqueeze(1).to(self.dtype)
self.img_embed = img_embed
@torch.no_grad()
def single_infer(self,input_rgb:torch.Tensor,
num_inference_steps:int,
domain:str,
show_pbar:bool,):
device = input_rgb.device
# Set timesteps: inherit from the diffuison pipeline
self.scheduler.set_timesteps(num_inference_steps, device=device) # here the numbers of the steps is only 10.
timesteps = self.scheduler.timesteps # [T]
# encode image
rgb_latent = self.encode_RGB(input_rgb)
# Initial geometric maps (Guassian noise)
geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
rgb_latent = rgb_latent.repeat(2,1,1,1)
# Batched img embedding
if self.img_embed is None:
self.__encode_img_embed(input_rgb)
batch_img_embed = self.img_embed.repeat(
(rgb_latent.shape[0], 1, 1)
) # [B, 1, 768]
# hybrid switcher
geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype)
geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
if domain == "indoor":
domain_class = torch.tensor([[1., 0., 0]], device=device, dtype=self.dtype).repeat(2,1)
elif domain == "outdoor":
domain_class = torch.tensor([[0., 1., 0]], device=device, dtype=self.dtype).repeat(2,1)
elif domain == "object":
domain_class = torch.tensor([[0., 0., 1]], device=device, dtype=self.dtype).repeat(2,1)
domain_embedding = torch.cat([torch.sin(domain_class), torch.cos(domain_class)], dim=-1)
class_embedding = torch.cat((geo_embedding, domain_embedding), dim=-1)
# Denoising loop
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
unet_input = torch.cat([rgb_latent, geo_latent], dim=1)
# predict the noise residual
noise_pred = self.unet(
unet_input, t.repeat(2), encoder_hidden_states=batch_img_embed, class_labels=class_embedding
).sample # [B, 4, h, w]
# compute the previous noisy sample x_t -> x_t-1
geo_latent = self.scheduler.step(noise_pred, t, geo_latent).prev_sample
geo_latent = geo_latent
torch.cuda.empty_cache()
depth = self.decode_depth(geo_latent[0][None])
depth = torch.clip(depth, -1.0, 1.0)
depth = (depth + 1.0) / 2.0
normal = self.decode_normal(geo_latent[1][None])
normal /= (torch.norm(normal, p=2, dim=1, keepdim=True)+1e-5)
normal *= -1.
return depth, normal
def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
# encode
h = self.vae.encoder(rgb_in)
moments = self.vae.quant_conv(h)
mean, logvar = torch.chunk(moments, 2, dim=1)
# scale latent
rgb_latent = mean * self.latent_scale_factor
return rgb_latent
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
"""
Decode depth latent into depth map.
Args:
depth_latent (`torch.Tensor`):
Depth latent to be decoded.
Returns:
`torch.Tensor`: Decoded depth map.
"""
# scale latent
depth_latent = depth_latent / self.latent_scale_factor
# decode
z = self.vae.post_quant_conv(depth_latent)
stacked = self.vae.decoder(z)
# mean of output channels
depth_mean = stacked.mean(dim=1, keepdim=True)
return depth_mean
def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
"""
Decode normal latent into normal map.
Args:
normal_latent (`torch.Tensor`):
Depth latent to be decoded.
Returns:
`torch.Tensor`: Decoded normal map.
"""
# scale latent
normal_latent = normal_latent / self.latent_scale_factor
# decode
z = self.vae.post_quant_conv(normal_latent)
normal = self.vae.decoder(z)
return normal

View File

@@ -0,0 +1,463 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Some modifications are reimplemented in public environments by Xiao Fu and Mu Hu
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from geo_models.attention import BasicTransformerBlock
from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches and norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif self.is_input_patches and norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,63 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import torch
import math
# Search table for suggested max. inference batch size
bs_search_table = [
# tested on A100-PCIE-80GB
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
# tested on A100-PCIE-40GB
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
# tested on RTX3090, RTX4090
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
# tested on GTX1080Ti
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
"""
Automatically search for suitable operating batch size.
Args:
ensemble_size (`int`):
Number of predictions to be ensembled.
input_res (`int`):
Operating resolution of the input image.
Returns:
`int`: Operating batch size.
"""
if not torch.cuda.is_available():
return 1
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
for settings in sorted(
filtered_bs_search_table,
key=lambda k: (k["res"], -k["total_vram"]),
):
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
bs = settings["bs"]
if bs > ensemble_size:
bs = ensemble_size
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
bs = math.ceil(ensemble_size / 2)
return bs
return 1

View File

@@ -0,0 +1,45 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import numpy as np
import cv2
def kitti_colormap(disparity, maxval=-1):
"""
A utility function to reproduce KITTI fake colormap
Arguments:
- disparity: numpy float32 array of dimension HxW
- maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used)
Returns a numpy uint8 array of shape HxWx3.
"""
if maxval < 0:
maxval = np.max(disparity)
colormap = np.asarray([[0,0,0,114],[0,0,1,185],[1,0,0,114],[1,0,1,174],[0,1,0,114],[0,1,1,185],[1,1,0,114],[1,1,1,0]])
weights = np.asarray([8.771929824561404,5.405405405405405,8.771929824561404,5.747126436781609,8.771929824561404,5.405405405405405,8.771929824561404,0])
cumsum = np.asarray([0,0.114,0.299,0.413,0.587,0.701,0.8859999999999999,0.9999999999999999])
colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3])
values = np.expand_dims(np.minimum(np.maximum(disparity/maxval, 0.), 1.), -1)
bins = np.repeat(np.repeat(np.expand_dims(np.expand_dims(cumsum,axis=0),axis=0), disparity.shape[1], axis=1), disparity.shape[0], axis=0)
diffs = np.where((np.repeat(values, 8, axis=-1) - bins) > 0, -1000, (np.repeat(values, 8, axis=-1) - bins))
index = np.argmax(diffs, axis=-1)-1
w = 1-(values[:,:,0]-cumsum[index])*np.asarray(weights)[index]
colored_disp[:,:,2] = (w*colormap[index][:,:,0] + (1.-w)*colormap[index+1][:,:,0])
colored_disp[:,:,1] = (w*colormap[index][:,:,1] + (1.-w)*colormap[index+1][:,:,1])
colored_disp[:,:,0] = (w*colormap[index][:,:,2] + (1.-w)*colormap[index+1][:,:,2])
return (colored_disp*np.expand_dims((disparity>0),-1)*255).astype(np.uint8)
def read_16bit_gt(path):
"""
A utility function to read KITTI 16bit gt
Arguments:
- path: filepath
Returns a numpy float32 array of shape HxW.
"""
gt = cv2.imread(path,-1).astype(np.float32)/256.
return gt

View File

@@ -0,0 +1,42 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import json
import yaml
import logging
import os
import numpy as np
import sys
def load_loss_scheme(loss_config):
with open(loss_config, 'r') as f:
loss_json = yaml.safe_load(f)
return loss_json
DEBUG =0
logger = logging.getLogger()
if DEBUG:
#coloredlogs.install(level='DEBUG')
logger.setLevel(logging.DEBUG)
else:
#coloredlogs.install(level='INFO')
logger.setLevel(logging.INFO)
strhdlr = logging.StreamHandler()
logger.addHandler(strhdlr)
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s')
strhdlr.setFormatter(formatter)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def check_path(path):
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)

View File

@@ -0,0 +1,81 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append("..")
from dataloader.mix_loader import MixDataset
from torch.utils.data import DataLoader
from dataloader import transforms
import os
# Get Dataset Here
def prepare_dataset(data_dir=None,
batch_size=1,
test_batch=1,
datathread=4,
logger=None):
# set the config parameters
dataset_config_dict = dict()
train_dataset = MixDataset(data_dir=data_dir)
img_height, img_width = train_dataset.get_img_size()
datathread = datathread
if os.environ.get('datathread') is not None:
datathread = int(os.environ.get('datathread'))
if logger is not None:
logger.info("Use %d processes to load data..." % datathread)
train_loader = DataLoader(train_dataset, batch_size = batch_size, \
shuffle = True, num_workers = datathread, \
pin_memory = True)
num_batches_per_epoch = len(train_loader)
dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch
dataset_config_dict['img_size'] = (img_height,img_width)
return train_loader, dataset_config_dict
def depth_scale_shift_normalization(depth):
bsz = depth.shape[0]
depth_ = depth[:,0,:,:].reshape(bsz,-1).cpu().numpy()
min_value = torch.from_numpy(np.percentile(a=depth_,q=2,axis=1)).to(depth)[...,None,None,None]
max_value = torch.from_numpy(np.percentile(a=depth_,q=98,axis=1)).to(depth)[...,None,None,None]
normalized_depth = ((depth - min_value)/(max_value-min_value+1e-5) - 0.5) * 2
normalized_depth = torch.clip(normalized_depth, -1., 1.)
return normalized_depth
def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
assert input_tensor.shape[1]==3
original_H, original_W = input_tensor.shape[2:]
downscale_factor = min(recom_resolution/original_H, recom_resolution/original_W)
if mode == 'normal':
resized_input_tensor = F.interpolate(input_tensor,
scale_factor=downscale_factor,
mode='nearest')
else:
resized_input_tensor = F.interpolate(input_tensor,
scale_factor=downscale_factor,
mode='bilinear',
align_corners=False)
if mode == 'depth':
return resized_input_tensor / downscale_factor
else:
return resized_input_tensor

View File

@@ -0,0 +1,33 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import numpy as np
from scipy.optimize import least_squares
import torch
def align_scale_shift(pred, target, clip_max):
mask = (target > 0) & (target < clip_max)
if mask.sum() > 10:
target_mask = target[mask]
pred_mask = pred[mask]
scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
return scale, shift
else:
return 1, 0
def align_scale(pred: torch.tensor, target: torch.tensor):
mask = target > 0
if torch.sum(mask) > 10:
scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
else:
scale = 1
pred_scale = pred * scale
return pred_scale, scale
def align_shift(pred: torch.tensor, target: torch.tensor):
mask = target > 0
if torch.sum(mask) > 10:
shift = torch.median(target[mask]) - (torch.median(pred[mask]) + 1e-8)
else:
shift = 0
pred_shift = pred + shift
return pred_shift, shift

View File

@@ -0,0 +1,186 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import pickle
import os
# import h5py
import numpy as np
import cv2
import torch
import torch.nn as nn
import glob
def init_image_coor(height, width):
x_row = np.arange(0, width)
x = np.tile(x_row, (height, 1))
x = x[np.newaxis, :, :]
x = x.astype(np.float32)
x = torch.from_numpy(x.copy()).cuda()
u_u0 = x - width/2.0
y_col = np.arange(0, height) # y_col = np.arange(0, height)
y = np.tile(y_col, (width, 1)).T
y = y[np.newaxis, :, :]
y = y.astype(np.float32)
y = torch.from_numpy(y.copy()).cuda()
v_v0 = y - height/2.0
return u_u0, v_v0
def depth_to_xyz(depth, focal_length):
b, c, h, w = depth.shape
u_u0, v_v0 = init_image_coor(h, w)
x = u_u0 * depth / focal_length[0]
y = v_v0 * depth / focal_length[1]
z = depth
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
return pw
def get_surface_normal(xyz, patch_size=5):
# xyz: [1, h, w, 3]
x, y, z = torch.unbind(xyz, dim=3)
x = torch.unsqueeze(x, 0)
y = torch.unsqueeze(y, 0)
z = torch.unsqueeze(z, 0)
xx = x * x
yy = y * y
zz = z * z
xy = x * y
xz = x * z
yz = y * z
patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda()
xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2))
yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2))
zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2))
xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2))
xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2))
yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2))
ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch],
dim=4)
ATA = torch.squeeze(ATA)
ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
ATA = ATA + eps_identity
x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2))
y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2))
z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2))
AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4)
AT1 = torch.squeeze(AT1)
AT1 = torch.unsqueeze(AT1, 3)
patch_num = 4
patch_x = int(AT1.size(1) / patch_num)
patch_y = int(AT1.size(0) / patch_num)
n_img = torch.randn(AT1.shape).cuda()
overlap = patch_size // 2 + 1
for x in range(int(patch_num)):
for y in range(int(patch_num)):
left_flg = 0 if x == 0 else 1
right_flg = 0 if x == patch_num -1 else 1
top_flg = 0 if y == 0 else 1
btm_flg = 0 if y == patch_num - 1 else 1
at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
# n_img_tmp, _ = torch.solve(at1, ata)
n_img_tmp = torch.linalg.solve(ata, at1)
n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :]
n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select
n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
n_img_norm = n_img / n_img_L2
# re-orient normals consistently
orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
n_img_norm[orient_mask] *= -1
return n_img_norm
def get_surface_normalv2(xyz, patch_size=5):
"""
xyz: xyz coordinates
patch: [p1, p2, p3,
p4, p5, p6,
p7, p8, p9]
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
return: normal [h, w, 3, b]
"""
b, h, w, c = xyz.shape
half_patch = patch_size // 2
xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
# xyz_left_top = xyz_pad[:, :h, :w, :] # p1
# xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9
# xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7
# xyz_right_top = xyz_pad[:, :h, -w:, :] # p3
# xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9
# xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3
xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8
xyz_horizon = xyz_left - xyz_right # p4p6
xyz_vertical = xyz_top - xyz_bottom # p2p8
xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8
xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6
xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
# re-orient normals consistently
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
n_img_1[orient_mask] *= -1
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
n_img_2[orient_mask] *= -1
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
# average 2 norms
n_img_aver = n_img1_norm + n_img2_norm
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
# re-orient normals consistently
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
n_img_aver_norm[orient_mask] *= -1
n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b]
# a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze()
# plt.imshow(np.abs(a), cmap='rainbow')
# plt.show()
return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0))
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
# para depth: depth map, [b, c, h, w]
b, c, h, w = depth.shape
focal_length = focal_length[:, None, None, None]
depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
#depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
xyz = depth_to_xyz(depth_filter, focal_length)
sn_batch = []
for i in range(b):
xyz_i = xyz[i, :][None, :, :, :]
#normal = get_surface_normalv2(xyz_i)
normal = get_surface_normal(xyz_i)
sn_batch.append(normal)
sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w]
if valid_mask != None:
mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
sn_batch[mask_invalid] = 0.0
return sn_batch

View File

@@ -0,0 +1,115 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import numpy as np
import torch
from scipy.optimize import minimize
def inter_distances(tensors: torch.Tensor):
"""
To calculate the distance between each two depth maps.
"""
distances = []
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
arr1 = tensors[i : i + 1]
arr2 = tensors[j : j + 1]
distances.append(arr1 - arr2)
dist = torch.concat(distances, dim=0)
return dist
def ensemble_depths(input_images:torch.Tensor,
regularizer_strength: float =0.02,
max_iter: int =2,
tol:float =1e-3,
reduction: str='median',
max_res: int=None):
"""
To ensemble multiple affine-invariant depth images (up to scale and shift),
by aligning estimating the scale and shift
"""
device = input_images.device
dtype = input_images.dtype
np_dtype = np.float32
original_input = input_images.clone()
n_img = input_images.shape[0]
ori_shape = input_images.shape
if max_res is not None:
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
if scale_factor < 1:
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
input_images = downscaler(torch.from_numpy(input_images)).numpy()
# init guess
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the min value of each possible depth
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the max value of each possible depth
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) #(10,1,1) : re-scale'f scale
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) #(10,1,1)
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) #(20,)
input_images = input_images.to(device)
# objective function
def closure(x):
l = len(x)
s = x[: int(l / 2)]
t = x[int(l / 2) :]
s = torch.from_numpy(s).to(dtype=dtype).to(device)
t = torch.from_numpy(t).to(dtype=dtype).to(device)
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
dists = inter_distances(transformed_arrays)
sqrt_dist = torch.sqrt(torch.mean(dists**2))
if "mean" == reduction:
pred = torch.mean(transformed_arrays, dim=0)
elif "median" == reduction:
pred = torch.median(transformed_arrays, dim=0).values
else:
raise ValueError
near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
err = sqrt_dist + (near_err + far_err) * regularizer_strength
err = err.detach().cpu().numpy().astype(np_dtype)
return err
res = minimize(
closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
)
x = res.x
l = len(x)
s = x[: int(l / 2)]
t = x[int(l / 2) :]
# Prediction
s = torch.from_numpy(s).to(dtype=dtype).to(device)
t = torch.from_numpy(t).to(dtype=dtype).to(device)
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) #[10,H,W]
if "mean" == reduction:
aligned_images = torch.mean(transformed_arrays, dim=0)
std = torch.std(transformed_arrays, dim=0)
uncertainty = std
elif "median" == reduction:
aligned_images = torch.median(transformed_arrays, dim=0).values
# MAD (median absolute deviation) as uncertainty indicator
abs_dev = torch.abs(transformed_arrays - aligned_images)
mad = torch.median(abs_dev, dim=0).values
uncertainty = mad
# Scale and shift to [0, 1]
_min = torch.min(aligned_images)
_max = torch.max(aligned_images)
aligned_images = (aligned_images - _min) / (_max - _min)
uncertainty /= _max - _min
return aligned_images, uncertainty

View File

@@ -0,0 +1,83 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import matplotlib
import numpy as np
import torch
from PIL import Image
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
"""
Resize image to limit maximum edge length while keeping aspect ratio.
Args:
img (`Image.Image`):
Image to be resized.
max_edge_resolution (`int`):
Maximum edge length (pixel).
Returns:
`Image.Image`: Resized image.
"""
original_width, original_height = img.size
downscale_factor = min(
max_edge_resolution / original_width, max_edge_resolution / original_height
)
new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor)
resized_img = img.resize((new_width, new_height))
return resized_img
def colorize_depth_maps(
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
):
"""
Colorize depth maps.
"""
assert len(depth_map.shape) >= 2, "Invalid dimension"
if isinstance(depth_map, torch.Tensor):
depth = depth_map.detach().clone().squeeze().numpy()
elif isinstance(depth_map, np.ndarray):
depth = depth_map.copy().squeeze()
# reshape to [ (B,) H, W ]
if depth.ndim < 3:
depth = depth[np.newaxis, :, :]
# colorize
cm = matplotlib.colormaps[cmap]
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
if valid_mask is not None:
if isinstance(depth_map, torch.Tensor):
valid_mask = valid_mask.detach().numpy()
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
if valid_mask.ndim < 3:
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
else:
valid_mask = valid_mask[:, np.newaxis, :, :]
valid_mask = np.repeat(valid_mask, 3, axis=1)
img_colored_np[~valid_mask] = 0
if isinstance(depth_map, torch.Tensor):
img_colored = torch.from_numpy(img_colored_np).float()
elif isinstance(depth_map, np.ndarray):
img_colored = img_colored_np
return img_colored
def chw2hwc(chw):
assert 3 == len(chw.shape)
if isinstance(chw, torch.Tensor):
hwc = torch.permute(chw, (1, 2, 0))
elif isinstance(chw, np.ndarray):
hwc = np.moveaxis(chw, 0, -1)
return hwc

View File

@@ -0,0 +1,22 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import numpy as np
import torch
def ensemble_normals(input_images:torch.Tensor):
normal_preds = input_images
bsz, d, h, w = normal_preds.shape
normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5)
phi = torch.atan2(normal_preds[:,1,:,:], normal_preds[:,0,:,:]).mean(dim=0)
theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0)
normal_pred = torch.zeros((d,h,w)).to(normal_preds)
normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi)
normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi)
normal_pred[2,:,:] = torch.cos(theta)
angle_error = torch.acos(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1))
normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1))
return normal_preds[normal_idx]

View File

@@ -0,0 +1,33 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import numpy as np
import random
import torch
def seed_all(seed: int = 0):
"""
Set random seeds of all components.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

View File

@@ -0,0 +1,213 @@
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import torch
import numpy as np
import torch.nn as nn
def init_image_coor(height, width):
x_row = np.arange(0, width)
x = np.tile(x_row, (height, 1))
x = x[np.newaxis, :, :]
x = x.astype(np.float32)
x = torch.from_numpy(x.copy()).cuda()
u_u0 = x - width/2.0
y_col = np.arange(0, height) # y_col = np.arange(0, height)
y = np.tile(y_col, (width, 1)).T
y = y[np.newaxis, :, :]
y = y.astype(np.float32)
y = torch.from_numpy(y.copy()).cuda()
v_v0 = y - height/2.0
return u_u0, v_v0
def depth_to_xyz(depth, focal_length):
b, c, h, w = depth.shape
u_u0, v_v0 = init_image_coor(h, w)
x = u_u0 * depth / focal_length
y = v_v0 * depth / focal_length
z = depth
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
return pw
def get_surface_normal(xyz, patch_size=3):
# xyz: [1, h, w, 3]
x, y, z = torch.unbind(xyz, dim=3)
x = torch.unsqueeze(x, 0)
y = torch.unsqueeze(y, 0)
z = torch.unsqueeze(z, 0)
xx = x * x
yy = y * y
zz = z * z
xy = x * y
xz = x * z
yz = y * z
patch_weight = torch.ones((1, 1, patch_size, patch_size), requires_grad=False).cuda()
xx_patch = nn.functional.conv2d(xx, weight=patch_weight, padding=int(patch_size / 2))
yy_patch = nn.functional.conv2d(yy, weight=patch_weight, padding=int(patch_size / 2))
zz_patch = nn.functional.conv2d(zz, weight=patch_weight, padding=int(patch_size / 2))
xy_patch = nn.functional.conv2d(xy, weight=patch_weight, padding=int(patch_size / 2))
xz_patch = nn.functional.conv2d(xz, weight=patch_weight, padding=int(patch_size / 2))
yz_patch = nn.functional.conv2d(yz, weight=patch_weight, padding=int(patch_size / 2))
ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch],
dim=4)
ATA = torch.squeeze(ATA)
ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
ATA = ATA + eps_identity
x_patch = nn.functional.conv2d(x, weight=patch_weight, padding=int(patch_size / 2))
y_patch = nn.functional.conv2d(y, weight=patch_weight, padding=int(patch_size / 2))
z_patch = nn.functional.conv2d(z, weight=patch_weight, padding=int(patch_size / 2))
AT1 = torch.stack([x_patch, y_patch, z_patch], dim=4)
AT1 = torch.squeeze(AT1)
AT1 = torch.unsqueeze(AT1, 3)
patch_num = 4
patch_x = int(AT1.size(1) / patch_num)
patch_y = int(AT1.size(0) / patch_num)
n_img = torch.randn(AT1.shape).cuda()
overlap = patch_size // 2 + 1
for x in range(int(patch_num)):
for y in range(int(patch_num)):
left_flg = 0 if x == 0 else 1
right_flg = 0 if x == patch_num -1 else 1
top_flg = 0 if y == 0 else 1
btm_flg = 0 if y == patch_num - 1 else 1
at1 = AT1[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
ata = ATA[y * patch_y - top_flg * overlap:(y + 1) * patch_y + btm_flg * overlap,
x * patch_x - left_flg * overlap:(x + 1) * patch_x + right_flg * overlap]
n_img_tmp, _ = torch.solve(at1, ata)
n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap, left_flg * overlap:patch_x + left_flg * overlap, :, :]
n_img[y * patch_y:y * patch_y + patch_y, x * patch_x:x * patch_x + patch_x, :, :] = n_img_tmp_select
n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
n_img_norm = n_img / n_img_L2
# re-orient normals consistently
orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
n_img_norm[orient_mask] *= -1
return n_img_norm
def get_surface_normalv2(xyz, patch_size=3):
"""
xyz: xyz coordinates
patch: [p1, p2, p3,
p4, p5, p6,
p7, p8, p9]
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
return: normal [h, w, 3, b]
"""
b, h, w, c = xyz.shape
half_patch = patch_size // 2
xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
# xyz_left_top = xyz_pad[:, :h, :w, :] # p1
# xyz_right_bottom = xyz_pad[:, -h:, -w:, :]# p9
# xyz_left_bottom = xyz_pad[:, -h:, :w, :] # p7
# xyz_right_top = xyz_pad[:, :h, -w:, :] # p3
# xyz_cross1 = xyz_left_top - xyz_right_bottom # p1p9
# xyz_cross2 = xyz_left_bottom - xyz_right_top # p7p3
xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8
xyz_horizon = xyz_left - xyz_right # p4p6
xyz_vertical = xyz_top - xyz_bottom # p2p8
xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8
xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6
xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
# re-orient normals consistently
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
n_img_1[orient_mask] *= -1
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
n_img_2[orient_mask] *= -1
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True))
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True))
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
# average 2 norms
n_img_aver = n_img1_norm + n_img2_norm
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True))
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
# re-orient normals consistently
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
n_img_aver_norm[orient_mask] *= -1
n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b]
# a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze()
# plt.imshow(np.abs(a), cmap='rainbow')
# plt.show()
return n_img_aver_norm_out#n_img1_norm.permute((1, 2, 3, 0))
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
# para depth: depth map, [b, c, h, w]
b, c, h, w = depth.shape
focal_length = focal_length[:, None, None, None]
depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
xyz = depth_to_xyz(depth_filter, focal_length)
sn_batch = []
for i in range(b):
xyz_i = xyz[i, :][None, :, :, :]
normal = get_surface_normalv2(xyz_i)
sn_batch.append(normal)
sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1)) # [b, c, h, w]
mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
sn_batch[mask_invalid] = 0.0
return sn_batch
def vis_normal(normal):
"""
Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
@para normal: surface normal, [h, w, 3], numpy.array
"""
n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
n_img_norm = normal / (n_img_L2 + 1e-8)
normal_vis = n_img_norm * 127
normal_vis += 128
normal_vis = normal_vis.astype(np.uint8)
return normal_vis
def vis_normal2(normals):
'''
Montage of normal maps. Vectors are unit length and backfaces thresholded.
'''
x = normals[:, :, 0] # horizontal; pos right
y = normals[:, :, 1] # depth; pos far
z = normals[:, :, 2] # vertical; pos up
backfacing = (z > 0)
norm = np.sqrt(np.sum(normals**2, axis=2))
zero = (norm < 1e-5)
x += 1.0; x *= 0.5
y += 1.0; y *= 0.5
z = np.abs(z)
x[zero] = 0.0
y[zero] = 0.0
z[zero] = 0.0
normals[:, :, 0] = x # horizontal; pos right
normals[:, :, 1] = y # depth; pos far
normals[:, :, 2] = z # vertical; pos up
return normals
if __name__ == '__main__':
import cv2, os

View File

@@ -0,0 +1,6 @@
{
"tag": "Computer Vision: Depth, Normal, and Geometry",
"title": "GeoWizard: Unleashing the Diffusion Priors for 3D Geometry Estimation from a Single Image",
"repo_id": "lemonaddie/geowizard",
"revision": "e25e940c5c94c05be7ca84182a0aec7eb414edaa"
}