Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-06 21:49:57 +00:00
Bug fixes for IP adapter training. Added a CLIP preprocessor that can be trained along with the IP adapter to augment the CLIP input and squeeze more detail out of a larger input image. Moved CLIP processing to the dataloader for speed.
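In short: CLIP images are now run through the CLIP image processor once in the dataloader, and get_clip_image_embeds_from_tensors() gains a has_been_preprocessed flag so the trainer can skip the per-step resize/rescale. A rough sketch of the new call pattern (the adapter, dtype, and clip_images objects are assumed to exist as in the trainer code below):

# clip_images: tensors already passed through CLIPImageProcessor in the dataloader
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
    clip_images.detach().to(self.device_torch, dtype=dtype),
    is_training=True,
    has_been_preprocessed=True,  # skip resize/rescale inside the adapter
)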
@@ -666,7 +666,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
-guidance_scale=1.0,
+guidance_scale=self.train_config.cfg_scale,
**pred_kwargs # adapter residuals in here
)
if was_unet_training:
@@ -980,40 +980,45 @@ class SDTrainer(BaseSDTrainProcess):

if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
+image_size = self.adapter.input_size
if is_reg:
# we will zero it out in the img embedder
clip_images = torch.zeros(
-(noisy_latents.shape[0], 3, 512, 512),
+(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach()
# drop will zero it out
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images,
drop=True,
-is_training=True
+is_training=True,
+has_been_preprocessed=True
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
-(noisy_latents.shape[0], 3, 512, 512),
+(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
is_training=True,
-drop=True
+drop=True,
+has_been_preprocessed=True
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
-is_training=True
+is_training=True,
+has_been_preprocessed=True
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
-(noisy_latents.shape[0], 3, 512, 512),
+(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
is_training=True,
-drop=True
+drop=True,
+has_been_preprocessed=True
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
@@ -1094,7 +1099,7 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
-guidance_scale=1.0,
+guidance_scale=self.train_config.cfg_scale,
**pred_kwargs
)
self.after_unet_predict()
@@ -678,6 +678,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
+if self.train_config.img_multiplier is not None:
+imgs = imgs * self.train_config.img_multiplier
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
@@ -1113,6 +1115,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)

+# load the adapters before the dataset as they may use the clip encoders
+if self.adapter_config is not None:
+self.setup_adapter()
flush()

### HOOk ###
@@ -1249,7 +1254,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()

if self.adapter_config is not None:
-self.setup_adapter()
+# self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),
@@ -272,6 +272,7 @@ class TrainConfig:

self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
self.do_cfg = kwargs.get('do_cfg', False)
+self.cfg_scale = kwargs.get('cfg_scale', 1.0)


class ModelConfig:
@@ -431,6 +431,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
for file in tqdm(file_list):
try:
file_item = FileItemDTO(
+sd=self.sd,
path=file,
dataset_config=dataset_config,
dataloader_transforms=self.transform,
@@ -1,3 +1,5 @@
import weakref
from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union
import torch
import random
@@ -10,8 +12,10 @@ from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessing
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin


if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.stable_diffusion_model import StableDiffusion

printed_messages = []
@@ -12,6 +12,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
+from transformers import CLIPImageProcessor

from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -27,7 +28,7 @@ from toolkit.train_tools import get_torch_dtype
if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO

from toolkit.stable_diffusion_model import StableDiffusion

# def get_associated_caption_from_img_path(img_path):
# https://demo.albumentations.ai/
@@ -565,8 +566,13 @@ class ClipImageFileItemDTOMixin:
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.has_clip_augmentations = False
self.clip_image_aug_transform: Union[None, A.Compose] = None
+self.clip_image_processor: Union[None, CLIPImageProcessor] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
+# copy the clip image processor so the dataloader can do it
+sd = kwargs.get('sd', None)
+if hasattr(sd.adapter, 'clip_image_processor'):
+self.clip_image_processor = sd.adapter.clip_image_processor
# find the control image path
clip_image_path = dataset_config.clip_image_path
# we are using control images
@@ -632,13 +638,22 @@ class ClipImageFileItemDTOMixin:
print(f"Error: {e}")
print(f"Error loading image: {self.clip_image_path}")

# we just scale them to 512x512:
img = img.resize((512, 512), Image.BICUBIC)
if self.has_clip_augmentations:
self.clip_image_tensor = self.augment_clip_image(img, transform=None)
else:
self.clip_image_tensor = transforms.ToTensor()(img)

+if self.clip_image_processor is not None:
+# run it
+tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)
+clip_out = self.clip_image_processor(
+images=tensors_0_1,
+return_tensors="pt",
+do_resize=True,
+do_rescale=False,
+).pixel_values
+self.clip_image_tensor = clip_out.squeeze(0).clone().detach()

def cleanup_clip_image(self: 'FileItemDTO'):
self.clip_image_tensor = None
@@ -5,12 +5,13 @@ from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype

sys.path.append(REPOS_ROOT)
-from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
@@ -163,7 +164,9 @@ class IPAdapter(torch.nn.Module):
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
self.device = self.sd_ref().unet.device
-if self.config.image_encoder_arch == 'clip':
+self.preprocessor: Optional[CLIPImagePreProcessor] = None
+self.input_size = 224
+if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
@@ -176,7 +179,8 @@ class IPAdapter(torch.nn.Module):
self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = ViTFeatureExtractor()
-self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
+self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
elif self.config.image_encoder_arch == 'safe':
try:
self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -220,6 +224,26 @@ class IPAdapter(torch.nn.Module):
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
else:
raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")

+self.input_size = self.image_encoder.config.image_size
+
+if self.config.image_encoder_arch == 'clip+':
+# self.clip_image_processor.config
+# We do a 3x downscale of the image, so we need to adjust the input size
+preprocessor_input_size = self.image_encoder.config.image_size * 3
+
+# update the preprocessor so images come in at the right size
+self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+
+self.preprocessor = CLIPImagePreProcessor(
+input_size=preprocessor_input_size,
+clip_input_size=self.image_encoder.config.image_size,
+downscale_factor=6
+)
+
+self.input_size = self.clip_image_processor.size['shortest_edge']
self.current_scale = 1.0
self.is_active = True
if adapter_config.type == 'ip':
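For the 'clip+' path, the size bookkeeping above works out as follows (a sketch assuming the usual 224-px CLIP vision tower; the names mirror the diff):

clip_size = 224                          # self.image_encoder.config.image_size
preprocessor_input_size = clip_size * 3  # 672, written into clip_image_processor.size / crop_size
# CLIPImagePreProcessor(input_size=672, clip_input_size=224, downscale_factor=6)
# maps (bs, 3, 672, 672) back down to (bs, 3, 224, 224) for the CLIP encoder,
# and self.input_size becomes 672 so callers hand in the larger images.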
@@ -232,12 +256,9 @@ class IPAdapter(torch.nn.Module):
elif adapter_config.type == 'ip+':
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
+embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
+self.image_encoder.config.hidden_sizes[-1]

-# if self.config.image_encoder_arch == 'safe':
-# embedding_dim = self.config.safe_tokens
-# size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
-# size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
# ip-adapter-plus
image_proj_model = Resampler(
dim=dim,
@@ -312,6 +333,8 @@ class IPAdapter(torch.nn.Module):
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
+if self.preprocessor is not None:
+self.preprocessor.to(*args, **kwargs)
return self

def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
@@ -320,6 +343,8 @@ class IPAdapter(torch.nn.Module):
ip_layers.load_state_dict(state_dict["ip_adapter"])
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"])
+if self.preprocessor is not None and 'preprocessor' in state_dict:
+self.preprocessor.load_state_dict(state_dict["preprocessor"])

# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)
@@ -330,6 +355,8 @@ class IPAdapter(torch.nn.Module):
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
if self.config.train_image_encoder:
state_dict["image_encoder"] = self.image_encoder.state_dict()
+if self.preprocessor is not None:
+state_dict["preprocessor"] = self.preprocessor.state_dict()
return state_dict

def get_scale(self):
@@ -341,53 +368,74 @@ class IPAdapter(torch.nn.Module):
if isinstance(attn_processor, CustomIPAttentionProcessor):
attn_processor.scale = scale

-@torch.no_grad()
-def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
-drop=False) -> torch.Tensor:
-# todo: add support for sdxl
-if isinstance(pil_image, Image.Image):
-pil_image = [pil_image]
-clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
-if drop:
-clip_image = clip_image * 0
-clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-return clip_image_embeds
+# @torch.no_grad()
+# def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
+# drop=False) -> torch.Tensor:
+# # todo: add support for sdxl
+# if isinstance(pil_image, Image.Image):
+# pil_image = [pil_image]
+# clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+# clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+# if drop:
+# clip_image = clip_image * 0
+# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+# return clip_image_embeds

-def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
-is_training=False) -> torch.Tensor:
+def get_clip_image_embeds_from_tensors(
+self,
+tensors_0_1: torch.Tensor,
+drop=False,
+is_training=False,
+has_been_preprocessed=False
+) -> torch.Tensor:
with torch.no_grad():
-# tensors should be 0-1
-# todo: add support for sdxl
-if tensors_0_1.ndim == 3:
-tensors_0_1 = tensors_0_1.unsqueeze(0)
-# training tensors are 0 - 1
-tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
-# if images are out of this range throw error
-if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
-raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
-tensors_0_1.min(), tensors_0_1.max()
-))
-
-clip_image = self.clip_image_processor(
-images=tensors_0_1,
-return_tensors="pt",
-do_resize=True,
-do_rescale=False,
-).pixel_values
+# on training the clip image is created in the dataloader
+if not has_been_preprocessed:
+# tensors should be 0-1
+if tensors_0_1.ndim == 3:
+tensors_0_1 = tensors_0_1.unsqueeze(0)
+# training tensors are 0 - 1
+tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+# if images are out of this range throw error
+if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+tensors_0_1.min(), tensors_0_1.max()
+))
+clip_image = self.clip_image_processor(
+images=tensors_0_1,
+return_tensors="pt",
+do_resize=True,
+do_rescale=False,
+).pixel_values
+else:
+clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if drop:
clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
-clip_output = self.image_encoder(clip_image.requires_grad_(True)
-, output_hidden_states=True)
+clip_image = clip_image.requires_grad_(True)
+if self.preprocessor is not None:
+clip_image = self.preprocessor(clip_image)
+clip_output = self.image_encoder(
+clip_image,
+output_hidden_states=True
+)
else:
self.image_encoder.eval()
-clip_output = self.image_encoder(clip_image, output_hidden_states=True)
+if self.preprocessor is not None:
+clip_image = self.preprocessor(clip_image)
+clip_output = self.image_encoder(
+clip_image, output_hidden_states=True
+)

-clip_image_embeds = clip_output.hidden_states[-2]
+if self.config.type.startswith('ip+'):
+# they skip last layer for ip+
+# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
+clip_image_embeds = clip_output.hidden_states[-2]
+else:
+clip_image_embeds = clip_output.image_embeds
return clip_image_embeds

# use drop for prompt dropout, or negatives
@@ -403,13 +451,8 @@ class IPAdapter(torch.nn.Module):
yield from self.image_proj_model.parameters(recurse)
if self.config.train_image_encoder:
yield from self.image_encoder.parameters(recurse)
-# if self.config.train_image_encoder:
-# yield from self.image_encoder.parameters(recurse)
-# self.image_encoder.train()
-# else:
-# for attn_processor in self.adapter_modules:
-# yield from attn_processor.parameters(recurse)
-# yield from self.image_proj_model.parameters(recurse)
+if self.preprocessor is not None:
+yield from self.preprocessor.parameters(recurse)

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
@@ -417,6 +460,8 @@ class IPAdapter(torch.nn.Module):
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
+if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
+self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)

def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True
toolkit/models/clip_pre_processor.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn


class UpsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_in = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.GELU()
        )
        self.conv_up = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.GELU()
        )

        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.conv_in(x)
        x = self.conv_up(x)
        x = self.conv_out(x)
        return x


class CLIPImagePreProcessor(nn.Module):
    def __init__(
        self,
        input_size=672,
        clip_input_size=224,
        downscale_factor: int = 6,
        channels=None,  # 108
    ):
        super().__init__()
        # make sure they are evenly divisible
        assert input_size % clip_input_size == 0
        in_channels = 3

        self.input_size = input_size
        self.clip_input_size = clip_input_size
        self.downscale_factor = downscale_factor

        subpixel_channels = in_channels * downscale_factor ** 2  # 3 * 6 ** 2 = 108

        if channels is None:
            channels = subpixel_channels

        upscale_factor = downscale_factor / int((input_size / clip_input_size))  # 6 / (672 / 224) = 2

        num_upsample_blocks = int(upscale_factor // 2)  # 2 // 2 = 1

        # do a pooling layer to downscale the input to 1/3 of the size
        # (bs, 3, 672, 672) -> (bs, 3, 224, 224)
        kernel_size = input_size // clip_input_size
        self.res_down = nn.AvgPool2d(
            kernel_size=kernel_size,
            stride=kernel_size
        )  # (bs, 3, 672, 672) -> (bs, 3, 224, 224)

        # make a blending for output residual with near 0 weight
        self.res_blend = nn.Parameter(torch.tensor(0.001))  # (bs, 3, 224, 224) -> (bs, 3, 224, 224)

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 672, 672) -> (bs, 108, 112, 112)

        self.conv_in = nn.Sequential(
            nn.Conv2d(
                subpixel_channels,
                channels,
                kernel_size=3,
                padding=1
            ),
            nn.GELU()
        )  # (bs, 108, 112, 112) -> (bs, 108, 112, 112)

        self.upsample_blocks = nn.ModuleList()
        current_channels = channels
        for _ in range(num_upsample_blocks):
            out_channels = current_channels // 2
            self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
            current_channels = out_channels

        # (bs, 108, 112, 112) -> (bs, 54, 224, 224)

        self.conv_out = nn.Conv2d(
            current_channels,
            out_channels=3,
            kernel_size=3,
            padding=1
        )  # (bs, 54, 224, 224) -> (bs, 3, 224, 224)

    def forward(self, x):
        # resize to input_size x input_size
        x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')

        res = self.res_down(x)

        x = self.unshuffle(x)
        x = self.conv_in(x)
        for up in self.upsample_blocks:
            x = up(x)
        x = self.conv_out(x)
        # blend residual
        x = x * self.res_blend + res
        return x
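A quick shape check for the new preprocessor (a sketch, not part of the commit; the 672/224/6 values mirror the 'clip+' wiring above):

import torch
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor

pre = CLIPImagePreProcessor(input_size=672, clip_input_size=224, downscale_factor=6)
x = torch.rand(1, 3, 672, 672)  # 0-1 image tensor; other sizes are interpolated up to 672 first
y = pre(x)                      # PixelUnshuffle -> conv -> one upsample block -> conv, blended with a pooled residual
print(y.shape)                  # torch.Size([1, 3, 224, 224]), ready for the frozen CLIP encoder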