marigold_ini

Author: lllyasviel
Date: 2024-02-02 00:14:37 -08:00
Parent: 24db0e241a
Commit: 7d7094b6ec
11 changed files with 703 additions and 1 deletion

@@ -0,0 +1,313 @@
# Author: Bingxin Ke
# Last modified: 2023-12-11
import logging
from typing import Dict
import numpy as np
import torch
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    PNDMScheduler,
    DEISMultistepScheduler,
    SchedulerMixin,
    UNet2DConditionModel,
)
from torch import nn
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from .rgb_encoder import RGBEncoder
from .stacked_depth_AE import StackedDepthAE


class MarigoldPipeline(nn.Module):
    """
    Marigold monocular depth estimator.
    """

    def __init__(
        self,
        unet_pretrained_path: Dict,  # {path: xxx, subfolder: xxx}
        rgb_encoder_pretrained_path: Dict,
        depth_ae_pretrained_path: Dict,
        noise_scheduler_pretrained_path: Dict,
        tokenizer_pretrained_path: Dict,
        text_encoder_pretrained_path: Dict,
        empty_text_embed=None,
        trainable_unet=False,
        rgb_latent_scale_factor=0.18215,
        depth_latent_scale_factor=0.18215,
        noise_scheduler_type=None,
        enable_gradient_checkpointing=False,
        enable_xformers=True,
    ) -> None:
        super().__init__()

        self.rgb_latent_scale_factor = rgb_latent_scale_factor
        self.depth_latent_scale_factor = depth_latent_scale_factor
        self.device = "cpu"

        # ******* Initialize modules *******
        # Trainable modules
        self.trainable_module_dic: Dict[str, nn.Module] = {}
        self.trainable_unet = trainable_unet

        # Denoising UNet
        self.unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
            unet_pretrained_path["path"], subfolder=unet_pretrained_path["subfolder"]
        )
        logging.info(f"pretrained UNet loaded from: {unet_pretrained_path}")
        if 8 != self.unet.config["in_channels"]:
            self._replace_unet_conv_in()
            logging.warning("UNet conv_in layer is replaced to accept 8 input channels")
        if enable_xformers:
            self.unet.enable_xformers_memory_efficient_attention()
        else:
            self.unet.disable_xformers_memory_efficient_attention()

        # Image encoder
        self.rgb_encoder = RGBEncoder(
            pretrained_path=rgb_encoder_pretrained_path["path"],
            subfolder=rgb_encoder_pretrained_path["subfolder"],
        )
        logging.info(
            f"pretrained RGBEncoder loaded from: {rgb_encoder_pretrained_path}"
        )
        self.rgb_encoder.requires_grad_(False)

        # Depth encoder-decoder
        self.depth_ae = StackedDepthAE(
            pretrained_path=depth_ae_pretrained_path["path"],
            subfolder=depth_ae_pretrained_path["subfolder"],
        )
        logging.info(
            f"pretrained Depth Autoencoder loaded from: {depth_ae_pretrained_path}"
        )

        # Trainability
        # unet
        if self.trainable_unet:
            self.unet.requires_grad_(True)
            self.trainable_module_dic["unet"] = self.unet
            logging.debug("UNet is set to trainable")
        else:
            self.unet.requires_grad_(False)
            logging.debug("UNet is set to frozen")

        # Gradient checkpointing
        if enable_gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()
            self.depth_ae.vae.enable_gradient_checkpointing()

        # Noise scheduler
        if "DDPMScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = DDPMScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        elif "DDIMScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = DDIMScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        elif "PNDMScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = PNDMScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        elif "DEISMultistepScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = DEISMultistepScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        else:
            raise NotImplementedError(
                f"Unsupported noise scheduler type: {noise_scheduler_type}"
            )

        # Text embed for empty prompt (always kept on CPU)
        if empty_text_embed is None:
            tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                tokenizer_pretrained_path["path"],
                subfolder=tokenizer_pretrained_path["subfolder"],
            )
            text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
                text_encoder_pretrained_path["path"],
                subfolder=text_encoder_pretrained_path["subfolder"],
            )
            with torch.no_grad():
                self.empty_text_embed = self._encode_text(
                    "", tokenizer, text_encoder
                ).detach()  # [1, 2, 1024]
        else:
            self.empty_text_embed = empty_text_embed

    @staticmethod
    def from_pretrained(pretrained_path, **kwargs):
        return MarigoldPipeline(
            unet_pretrained_path={"path": pretrained_path, "subfolder": "unet"},
            rgb_encoder_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
            depth_ae_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
            noise_scheduler_pretrained_path={
                "path": pretrained_path,
                "subfolder": "scheduler",
            },
            tokenizer_pretrained_path={
                "path": pretrained_path,
                "subfolder": "tokenizer",
            },
            text_encoder_pretrained_path={
                "path": pretrained_path,
                "subfolder": "text_encoder",
            },
            **kwargs,
        )

    def _replace_unet_conv_in(self):
        # Replace the first layer to accept 8 in_channels.
        # Only applied when loading a pretrained SD U-Net.
        _weight = self.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
        _bias = self.unet.conv_in.bias.clone()  # [320]
        _weight = _weight.repeat((1, 2, 1, 1))  # duplicate weights along the input-channel dim
        # Halve the magnitudes: the new layer then applies the original kernel
        # to the average of the two 4-channel halves (with the bias halved as
        # well), keeping activations at a similar scale.
        _weight *= 0.5
        _bias *= 0.5
        # new conv_in layer with 8 input channels
        _n_convin_out_channel = self.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.unet.conv_in = _new_conv_in
        # update the config to match
        self.unet.config["in_channels"] = 8
        return

    def to(self, device):
        self.rgb_encoder.to(device)
        self.depth_ae.to(device)
        self.unet.to(device)
        self.empty_text_embed = self.empty_text_embed.to(device)
        self.device = device
        return self

    def forward(
        self,
        rgb_in,
        num_inference_steps: int = 50,
        num_output_inter_results: int = 0,
        show_pbar=False,
        init_depth_latent=None,
        return_depth_latent=False,
    ):
        device = rgb_in.device
        precision = self.unet.dtype

        # Set timesteps
        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.noise_scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # Initial depth latent: use the given one, otherwise start from pure noise
        if init_depth_latent is not None:
            init_depth_latent = init_depth_latent.to(dtype=precision)
            assert (
                init_depth_latent.shape == rgb_latent.shape
            ), "initial depth latent should have shape [B, 4, H/8, W/8]"
            depth_latent = init_depth_latent
        else:
            depth_latent = torch.randn(
                rgb_latent.shape, device=device, dtype=precision
            )  # [B, 4, h, w]

        # Expand text embedding for batch
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device=device, dtype=precision)  # [B, 2, 1024]

        # Export intermediate denoising steps
        if num_output_inter_results > 0:
            depth_latent_ls = []
            inter_steps = []
            _idx = (
                -1
                * (
                    np.arange(0, num_output_inter_results)
                    * num_inference_steps
                    / num_output_inter_results
                )
                .round()
                .astype(int)
                - 1
            )
            steps_to_output = timesteps[_idx]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps), total=len(timesteps), leave=False, desc="denoising"
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, depth_latent], dim=1
            )  # this order is important
            unet_input = unet_input.to(dtype=precision)

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_{t-1}
            depth_latent = self.noise_scheduler.step(
                noise_pred, t, depth_latent
            ).prev_sample.to(dtype=precision)

            if num_output_inter_results > 0 and t in steps_to_output:
                depth_latent_ls.append(depth_latent.detach().clone())
                inter_steps.append(t - 1)

        # Decode depth latent(s)
        if num_output_inter_results > 0:
            assert 0 in inter_steps
            depth = [self.decode_depth(lat) for lat in depth_latent_ls]
            if return_depth_latent:
                return depth, inter_steps, depth_latent_ls
            else:
                return depth, inter_steps
        else:
            depth = self.decode_depth(depth_latent)
            if return_depth_latent:
                return depth, depth_latent
            else:
                return depth

    def encode_rgb(self, rgb_in):
        rgb_latent = self.rgb_encoder(rgb_in)  # [B, 4, h, w]
        rgb_latent = rgb_latent * self.rgb_latent_scale_factor
        return rgb_latent

    def encode_depth(self, depth_in):
        depth_latent = self.depth_ae.encode(depth_in)
        depth_latent = depth_latent * self.depth_latent_scale_factor
        return depth_latent

    def decode_depth(self, depth_latent):
        depth_latent = depth_latent / self.depth_latent_scale_factor
        depth = self.depth_ae.decode(depth_latent)  # [B, 1, H, W]
        return depth

    @staticmethod
    def _encode_text(prompt, tokenizer, text_encoder):
        text_inputs = tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(text_encoder.device)
        text_embed = text_encoder(text_input_ids)[0]
        return text_embed
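
For reference, a minimal usage sketch of the pipeline above (not part of this commit; "path/to/stable-diffusion-2" is a placeholder for any checkpoint laid out with unet/, vae/, scheduler/, tokenizer/ and text_encoder/ subfolders, and the input is assumed normalized to [-1, 1] as for the SD VAE):

import torch

# noise_scheduler_type must be one of the four supported scheduler names
pipe = MarigoldPipeline.from_pretrained(
    "path/to/stable-diffusion-2", noise_scheduler_type="DDIMScheduler"
).to("cuda")

rgb = torch.rand(1, 3, 512, 512, device="cuda") * 2.0 - 1.0  # RGB in [-1, 1]
with torch.no_grad():
    depth = pipe(rgb, num_inference_steps=50)  # [1, 1, 512, 512]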

rgb_encoder.py

@@ -0,0 +1,36 @@
# Author: Bingxin Ke
# Last modified: 2023-12-05
import torch
import torch.nn as nn
import logging
from diffusers import AutoencoderKL

class RGBEncoder(nn.Module):
    """
    The encoder of the pretrained Stable Diffusion VAE.
    """

    def __init__(self, pretrained_path, subfolder=None) -> None:
        super().__init__()
        vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            pretrained_path, subfolder=subfolder
        )
        logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")
        self.rgb_encoder = nn.Sequential(
            vae.encoder,
            vae.quant_conv,
        )

    def to(self, *args, **kwargs):
        self.rgb_encoder.to(*args, **kwargs)
        return self

    def forward(self, rgb_in):
        return self.encode(rgb_in)

    def encode(self, rgb_in):
        moments = self.rgb_encoder(rgb_in)  # [B, 8, H/8, W/8]
        mean, logvar = torch.chunk(moments, 2, dim=1)
        rgb_latent = mean  # deterministic encoding: take the mean, drop logvar
        return rgb_latent
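
A short sketch of this encoder in isolation (again with a placeholder checkpoint path; the VAE's 8x downsampling maps a 512x512 image to a 4-channel 64x64 latent):

import torch

encoder = RGBEncoder("path/to/stable-diffusion-2", subfolder="vae")
rgb = torch.rand(1, 3, 512, 512) * 2.0 - 1.0  # assumed normalized to [-1, 1]
with torch.no_grad():
    latent = encoder(rgb)  # [1, 4, 64, 64]: the mean half of the 8-channel moments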

stacked_depth_AE.py

@@ -0,0 +1,52 @@
# Author: Bingxin Ke
# Last modified: 2023-12-05
import torch
import torch.nn as nn
import logging
from diffusers import AutoencoderKL

class StackedDepthAE(nn.Module):
    """
    Pretrained image VAE tailored for depth maps.
    Encode: the depth map is repeated into 3 channels.
    Decode: the average of the 3 channels is taken as output.
    """

    def __init__(self, pretrained_path, subfolder=None) -> None:
        super().__init__()
        self.vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            pretrained_path, subfolder=subfolder
        )
        logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")

    def forward(self, depth_in):
        depth_latent = self.encode(depth_in)
        depth_out = self.decode(depth_latent)
        return depth_out

    def to(self, *args, **kwargs):
        self.vae.to(*args, **kwargs)
        return self

    @staticmethod
    def _stack_depth_images(depth_in):
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)  # [B, 1, H, W] -> [B, 3, H, W]
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1)  # [B, H, W] -> [B, 1, H, W]
            stacked = stacked.repeat(1, 3, 1, 1)
        return stacked

    def encode(self, depth_in):
        stacked = self._stack_depth_images(depth_in)
        h = self.vae.encoder(stacked)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        depth_latent = mean  # deterministic encoding: take the mean, drop logvar
        return depth_latent

    def decode(self, depth_latent):
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        depth_mean = stacked.mean(dim=1, keepdim=True)  # average 3 channels back to 1
        return depth_mean
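
And a round-trip sketch for the depth autoencoder (placeholder path as before; note that the scale factors applied in MarigoldPipeline.encode_depth / decode_depth are bypassed here):

import torch

depth_ae = StackedDepthAE("path/to/stable-diffusion-2", subfolder="vae")
depth = torch.rand(1, 1, 512, 512) * 2.0 - 1.0  # depth normalized to [-1, 1]
with torch.no_grad():
    latent = depth_ae.encode(depth)  # [1, 4, 64, 64]
    recon = depth_ae.decode(latent)  # [1, 1, 512, 512]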