mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
619 lines
20 KiB
Python
619 lines
20 KiB
Python
import math
|
|
from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
from dataclasses import dataclass
|
|
from huggingface_hub import snapshot_download
|
|
from safetensors.torch import load_file
|
|
import json
|
|
|
|
if TYPE_CHECKING:
|
|
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
|
|
|
|
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension.

    The normalization itself is computed in float32, cast back to the dtype
    of the input, and then multiplied by a learnable per-feature ``weight``
    (initialised to ones).

    Args:
        dim: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized input scaled by ``self.weight``."""
        # Normalize in float32, then go back to the caller's dtype before
        # applying the gain.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Bias-free gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the two parallel projections ``w1`` and ``w3``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim: int, hidden_dim: int, **kwargs):
        super().__init__()

        # Creation order (w1, w2, w3) is kept so that random initialisation
        # consumes the RNG in the same sequence.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated projection and map back to ``dim`` features."""
        gate = nn.functional.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)  # type: ignore
|
|
|
|
|
|
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Repeat every slice of ``keys`` and ``values`` along ``dim``.

    Each entry along ``dim`` is duplicated ``repeats`` times in place
    (``a, b`` becomes ``a, a, b, b`` for ``repeats=2``).

    Returns:
        The expanded ``(keys, values)`` pair.
    """
    expanded_keys = keys.repeat_interleave(repeats, dim=dim)
    expanded_values = values.repeat_interleave(repeats, dim=dim)
    return expanded_keys, expanded_values
|
|
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by complex rotary factors.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are read as
    (real, imaginary) parts of complex numbers, multiplied by ``freqs_cis``
    and written back as interleaved real pairs.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor; its last dimension must be even.
        freqs_cis: 2-D complex tensor. A new axis is inserted at position 1 so
            that the same factors broadcast over that axis of ``xq`` / ``xk``.

    Returns:
        ``(xq_rotated, xk_rotated)``, each cast back to the dtype of its input.
    """

    def _rotate(t: torch.Tensor, rotation: torch.Tensor) -> torch.Tensor:
        # The complex view needs float pairs, so compute in float32.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * rotation
        return torch.view_as_real(rotated).flatten(-2).type_as(t)

    rotation = freqs_cis[:, None, :]
    return _rotate(xq, rotation), _rotate(xk, rotation)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head attention with rotary embeddings over a flattened sequence.

    The input is a 2-D tensor of ``seqlen_sum`` tokens; attention is computed
    with xformers' ``memory_efficient_attention``, which is imported lazily
    inside ``forward``.

    Args:
        dim: Input/output feature size.
        n_heads: Number of query heads.
        head_dim: Feature size per head.
        n_kv_heads: Number of key/value heads; keys and values are repeated
            ``n_heads // n_kv_heads`` times to line up with the query heads.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        n_kv_heads: int,
        **kwargs,
    ):
        super().__init__()

        self.n_heads: int = n_heads
        self.head_dim: int = head_dim
        self.n_kv_heads: int = n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        # Stored for reference; forward() does not pass it to the kernel.
        self.scale = self.head_dim ** -0.5

        query_width = n_heads * head_dim
        kv_width = n_kv_heads * head_dim
        self.wq = nn.Linear(dim, query_width, bias=False)
        self.wk = nn.Linear(dim, kv_width, bias=False)
        self.wv = nn.Linear(dim, kv_width, bias=False)
        self.wo = nn.Linear(query_width, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[Any] = None,
        mask: Optional['BlockDiagonalMask'] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` and return a tensor of the same 2-D shape.

        Args:
            x: Token features of shape ``(seqlen_sum, dim)``.
            freqs_cis: Rotary factors handed to ``apply_rotary_emb``.
            cache: Optional key/value cache. When given, it must provide
                ``prefill``, ``interleave_kv``, ``update``, ``key``,
                ``value``, ``max_seq_len`` and ``mask``.
            mask: Optional attention bias, used only when ``cache`` is None.
                ``mask`` and ``cache`` must not both be supplied.
        """
        from xformers.ops.fmha import memory_efficient_attention

        assert mask is None or cache is None
        seqlen_sum, _ = x.shape

        xq = self.wq(x).view(seqlen_sum, self.n_heads, self.head_dim)
        xk = self.wk(x).view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if cache is None:
            key, val = xk, xv
        elif cache.prefill:
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            cache.update(xk, xv)
            flat_len = seqlen_sum * cache.max_seq_len
            key = cache.key.view(flat_len, self.n_kv_heads, self.head_dim)
            val = cache.value.view(flat_len, self.n_kv_heads, self.head_dim)

        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D)
        attn_bias = mask if cache is None else cache.mask
        output = memory_efficient_attention(
            xq[None, ...], key[None, ...], val[None, ...], attn_bias)
        output = output.view(seqlen_sum, self.n_heads * self.head_dim)

        assert isinstance(output, torch.Tensor)

        return self.wo(output)  # type: ignore
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each residual.

    Args:
        dim: Feature size of the token stream.
        hidden_dim: Hidden width of the feed-forward sub-layer.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads.
        head_dim: Feature size per attention head.
        norm_eps: Epsilon used by both RMSNorm layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        norm_eps: float,
        **kwargs,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.attention = Attention(
            dim=dim,
            n_heads=n_heads,
            head_dim=head_dim,
            n_kv_heads=n_kv_heads,
        )
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

        self.feed_forward: nn.Module
        self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[Any] = None,
        mask: Optional['BlockDiagonalMask'] = None,
    ) -> torch.Tensor:
        """Run one transformer layer.

        Args:
            x: Token features.
            freqs_cis: Rotary factors forwarded to the attention layer.
            cache: Optional key/value cache forwarded to the attention layer.
            mask: Optional attention bias forwarded to the attention layer.
                The attention layer requires that ``mask`` and ``cache`` are
                not both given.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The mask must reach the attention layer: without it the
        # block-diagonal structure is lost and tokens of different images
        # packed into one sequence attend to each other.
        r = self.attention.forward(
            self.attention_norm(x), freqs_cis, cache=cache, mask=mask)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out
|
|
|
|
|
|
@dataclass
class VisionEncoderArgs:
    """Hyper-parameters describing the vision encoder.

    The first seven fields are required; ``rope_theta`` and
    ``image_token_id`` have defaults.
    """

    # Transformer width; also the number of output channels of the patch conv.
    hidden_size: int
    # Number of input image channels.
    num_channels: int
    # Largest supported image side, in pixels.
    image_size: int
    # Side of one square patch, in pixels (conv kernel size and stride).
    patch_size: int
    # Hidden width of each feed-forward sub-layer.
    intermediate_size: int
    # Number of transformer layers.
    num_hidden_layers: int
    # Number of attention heads per layer.
    num_attention_heads: int
    rope_theta: float = 10000.0  # for rope-2D
    image_token_id: int = 10
|
|
|
|
|
|
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """Build the table of 2-D rotary factors.

    Returns a complex tensor of shape (height, width, dim // 2) that is
    indexed with (row, column) patch positions. Along the last axis, the first
    part depends only on the row and the remainder only on the column.

    Args:
        dim: Per-head feature size; ``dim // 2`` complex factors are produced
            per position.
        height: Number of rows in the table.
        width: Number of columns in the table.
        theta: Base of the geometric frequency progression.
    """
    # (dim / 2) frequency bases
    exponents = torch.arange(0, dim, 2).float() / dim
    freqs = 1.0 / (theta ** exponents)

    rows = torch.arange(height, device=freqs.device)
    cols = torch.arange(width, device=freqs.device)

    # Even-indexed bases are driven by the row index, odd-indexed ones by the
    # column index.
    row_angles = torch.outer(rows, freqs[::2]).float()
    col_angles = torch.outer(cols, freqs[1::2]).float()

    # Tile both angle sets over the full grid and join them feature-wise.
    row_part = row_angles[:, None, :].repeat(1, width, 1)
    col_part = col_angles[None, :, :].repeat(height, 1, 1)
    angles = torch.cat([row_part, col_part], dim=-1)

    # Unit-magnitude complex numbers exp(i * angle).
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
|
|
def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    """Enumerate (row, column) indices for every patch of every image.

    For each tensor, the last two dimensions are taken as the patch grid and
    all grid coordinates are listed in row-major order. The per-image lists
    are concatenated in input order.

    Returns:
        Integer tensor of shape (total_patches, 2).
    """
    per_image = []
    for patch_embeds in patch_embeds_list:
        rows = torch.arange(patch_embeds.shape[-2])
        cols = torch.arange(patch_embeds.shape[-1])
        grid = torch.meshgrid(rows, cols, indexing="ij")
        per_image.append(torch.stack(grid, dim=-1).reshape(-1, 2))
    positions = torch.cat(per_image)
    return positions
|
|
|
|
|
|
class PixtralVisionEncoder(nn.Module):
    """Vision transformer that encodes a list of variable-sized images.

    Each image is split into patches by a strided convolution. The patches of
    all images are concatenated into one token sequence, normalized, and
    passed through the transformer with 2-D rotary position factors and a
    block-diagonal mask whose blocks are the individual images.

    Args:
        hidden_size: Transformer width.
        num_channels: Number of input image channels.
        image_size: Largest supported image side in pixels; together with
            ``patch_size`` it fixes the size of the rotary table.
        patch_size: Side of one square patch in pixels.
        intermediate_size: Feed-forward hidden width.
        num_hidden_layers: Number of transformer layers.
        num_attention_heads: Number of attention heads.
        rope_theta: Base of the rotary frequency progression.
        image_token_id: Stored on ``self.args``; not used by this class.
        **kwargs: Accepted and ignored, so a full config dict can be splatted
            into the constructor.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,
    ):
        super().__init__()
        self.args = VisionEncoderArgs(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        args = self.args
        # Non-overlapping patches: kernel size and stride are both patch_size.
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = VisionTransformerBlocks(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Rotary table, built lazily by the ``freqs_cis`` property.
        self._freqs_cis: Optional[torch.Tensor] = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
        """Build an encoder from a local folder or a Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Existing directory, or a repo id
                that is downloaded with ``snapshot_download``.

        Returns:
            The model built from ``config.json``. Weights are loaded from
            ``model.safetensors`` when that file exists; otherwise the model
            keeps its freshly initialised weights.

        Raises:
            ValueError: If the folder contains no ``config.json``.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            model_folder = pretrained_model_name_or_path
        else:
            model_folder = snapshot_download(pretrained_model_name_or_path)

        # make sure there is a config
        config_path = os.path.join(model_folder, "config.json")
        if not os.path.exists(config_path):
            raise ValueError(f"Could not find config.json in {model_folder}")

        # load config (JSON is UTF-8; do not depend on the locale encoding)
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        model = cls(**config)

        # see if there is a state_dict
        weights_path = os.path.join(model_folder, "model.safetensors")
        if os.path.exists(weights_path):
            state_dict = load_file(weights_path)
            model.load_state_dict(state_dict)

        return model

    @property
    def max_patches_per_side(self) -> int:
        """Number of patches along one side of the largest supported image."""
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the module."""
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """2-D rotary table, built on first access and kept on ``self.device``."""
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        # Follow the module if it has been moved since the table was built.
        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes, each of shape (C, H, W)

        Returns:
            image_features: tensor of token features for all tokens of all images of
                shape (N_toks, D)
        """
        # Imported lazily so the module can be imported without xformers.
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask

        assert isinstance(
            images, list), f"Expected list of images, got {type(images)}"
        assert all(len(img.shape) == 3 for img in
                   images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"

        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images
        ]

        # flatten to a single sequence: (D, h, w) -> (h * w, D) per image
        patch_embeds = torch.cat(
            [p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings: look up the rotary factors of every patch
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # the result is already one flat token sequence of shape (N_toks, D)
        return out  # type: ignore[no-any-return]
|
|
|
|
|
|
class VisionLanguageAdapter(nn.Module):
    """Two-layer MLP with GELU: ``w_out(gelu(w_in(x)))``.

    Args:
        in_dim: Feature size of the input.
        out_dim: Feature size of the hidden layer and of the output.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.w_in = nn.Linear(in_dim, out_dim, bias=True)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from ``in_dim`` to ``out_dim`` features."""
        hidden = self.w_in(x)
        activated = self.gelu(hidden)
        return self.w_out(activated)  # type: ignore[no-any-return]
|
|
|
|
|
|
class VisionTransformerBlocks(nn.Module):
    """Stack of ``args.num_hidden_layers`` identical ``TransformerBlock`` layers.

    Every layer uses as many key/value heads as query heads and an RMSNorm
    epsilon of 1e-5.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        per_head_dim = args.hidden_size // args.num_attention_heads
        self.layers = torch.nn.ModuleList(
            [
                TransformerBlock(
                    dim=args.hidden_size,
                    hidden_dim=args.intermediate_size,
                    n_heads=args.num_attention_heads,
                    n_kv_heads=args.num_attention_heads,
                    head_dim=per_head_dim,
                    norm_eps=1e-5,
                )
                for _ in range(args.num_hidden_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: 'BlockDiagonalMask',
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Feed ``x`` through every layer in order.

        ``mask`` and ``freqs_cis`` are handed unchanged to each layer.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
|
|
|
|
|
|
# Per-channel statistics used to normalize images after resizing.
# NOTE(review): these look like the commonly used CLIP normalization
# constants - confirm against the model card before changing them.
DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073]  # RGB
DATASET_STD = [0.26862954, 0.26130258, 0.27577711]  # RGB
|
|
|
|
|
|
def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """
    Standardize an image channel by channel: ``(image - mean) / std``.

    Args:
        image (torch.Tensor): Image to be normalized, shape (C, H, W), values in [0, 1].
        mean (torch.Tensor): Mean for each channel, length C.
        std (torch.Tensor): Standard deviation for each channel, length C.

    Returns:
        torch.Tensor: Normalized image with shape (C, H, W).

    Raises:
        AssertionError: If the channel count of ``image`` differs from the
            length of ``mean`` or ``std``.
    """
    assert image.shape[0] == len(mean) == len(
        std), f"{image.shape=}, {mean.shape=}, {std.shape=}"

    # Shape (C, 1, 1) lets the per-channel values broadcast over H and W.
    channel_mean = mean.view(-1, 1, 1)
    channel_std = std.view(-1, 1, 1)

    centered = image - channel_mean
    return centered / channel_std
|
|
|
|
|
|
def transform_image(image: torch.Tensor, new_size: tuple[int, int]) -> torch.Tensor:
    """
    Resize an image with bicubic interpolation, then normalize it.

    Args:
        image (torch.Tensor): Input image tensor of shape (C, H, W), values in [0, 1].
        new_size (tuple[int, int]): Target size (height, width) for resizing;
            passed unchanged as the ``size`` argument of ``interpolate``.

    Returns:
        torch.Tensor: Resized and normalized image tensor of shape (C, new_H, new_W).
    """
    # interpolate works on batches, so add a batch axis and drop it afterwards.
    batched = image.unsqueeze(0)
    resized_batch = torch.nn.functional.interpolate(
        batched,
        size=new_size,
        mode='bicubic',
        align_corners=False,
    )
    resized_image = resized_batch.squeeze(0)

    # Build the statistics on the same device and dtype as the input.
    mean = torch.tensor(DATASET_MEAN, device=image.device, dtype=image.dtype)
    std = torch.tensor(DATASET_STD, device=image.device, dtype=image.dtype)

    return normalize(resized_image, mean, std)
|
|
|
|
|
|
class PixtralVisionImagePreprocessor:
    """Resize an image to a whole number of patches and normalize it.

    Args:
        image_patch_size: Side of one square patch in pixels.
        max_image_size: Default limit used when deciding whether to shrink an
            image (see ``_image_to_num_tokens``).
    """

    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        self.image_patch_size = image_patch_size
        self.max_image_size = max_image_size
        self.image_token = 10

    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size = None) -> Tuple[int, int]:
        """Return ``(width_tokens, height_tokens)`` for ``img``.

        The image is shrunk (never enlarged) when ``sqrt(width * height)``
        exceeds ``max_image_size``; each side is then rounded up to a whole
        number of patches.

        Args:
            img: Tensor whose last two dimensions are (height, width).
            max_image_size: Overrides ``self.max_image_size`` when given.
        """
        w: Union[int, float]
        h: Union[int, float]

        if max_image_size is None:
            max_image_size = self.max_image_size

        w, h = img.shape[-1], img.shape[-2]

        # originally, pixtral used the largest of the 2 dimensions, but we
        # will use the base size of the image based on number of pixels.
        # ratio = max(h / self.max_image_size, w / self.max_image_size)  # original

        base_size = int(math.sqrt(w * h))
        ratio = base_size / max_image_size
        if ratio > 1:
            w = round(w / ratio)
            h = round(h / ratio)

        # Ceiling division: a partially filled patch still costs one token.
        width_tokens = (w - 1) // self.image_patch_size + 1
        height_tokens = (h - 1) // self.image_patch_size + 1

        return width_tokens, height_tokens

    def _target_size(self, image: torch.Tensor, max_image_size=None) -> Tuple[int, int]:
        """Return the resize target for ``image`` as ``(height, width)`` pixels.

        The order is (height, width) because that is what ``transform_image``
        hands to ``torch.nn.functional.interpolate``. Passing (width, height)
        transposes the output of non-square images.
        """
        width_tokens, height_tokens = self._image_to_num_tokens(
            image, max_image_size=max_image_size)
        assert width_tokens > 0
        assert height_tokens > 0
        return (
            height_tokens * self.image_patch_size,
            width_tokens * self.image_patch_size,
        )

    def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
        """
        Resize a single image to a multiple of the patch size and normalize it.

        Args:
            image: torch tensor with values 0-1 and shape of (C, H, W)
            max_image_size: Overrides ``self.max_image_size`` when given.

        Returns:
            processed_image: normalized tensor of shape (C, new_H, new_W), where
                both new_H and new_W are multiples of ``image_patch_size``

        Raises:
            ValueError: If ``image`` is 4-dimensional (batched) or contains
                values outside [0, 1].
        """
        # should not have batch
        if len(image.shape) == 4:
            raise ValueError(
                f"Expected image with shape (C, H, W), got {image.shape}")

        if image.min() < 0.0 or image.max() > 1.0:
            raise ValueError(
                f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")

        if max_image_size is None:
            max_image_size = self.max_image_size

        new_image_size = self._target_size(image, max_image_size=max_image_size)

        processed_image = transform_image(image, new_image_size)

        return processed_image
|
|
|
|
|
|
class PixtralVisionImagePreprocessorCompatibleReturn:
    """Result wrapper that exposes preprocessed images as ``pixel_values``."""

    def __init__(self, pixel_values) -> None:
        # Stored as given; no copy or conversion is made.
        self.pixel_values = pixel_values
|
|
|
|
|
|
# Compatible version with ai toolkit flow
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
    """Batch-capable preprocessor that returns an object with ``pixel_values``.

    Besides the base-class attributes it exposes ``size``, ``image_mean`` and
    ``image_std``.
    """

    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        # The base class already stores image_patch_size and max_image_size.
        super().__init__(
            image_patch_size=image_patch_size,
            max_image_size=max_image_size
        )
        self.size = {
            'height': max_image_size,
            'width': max_image_size
        }
        self.image_mean = DATASET_MEAN
        self.image_std = DATASET_STD

    def __call__(
        self,
        images,
        return_tensors="pt",
        do_resize=True,
        do_rescale=False,
        max_image_size=None,
    ) -> 'PixtralVisionImagePreprocessorCompatibleReturn':
        """Preprocess one image or a batch of images.

        Args:
            images: Tensor of shape (C, H, W) or (B, C, H, W), values in [0, 1].
            return_tensors: Accepted for call compatibility; not used.
            do_resize: Accepted for call compatibility; not used.
            do_rescale: Accepted for call compatibility; not used.
            max_image_size: Overrides ``self.max_image_size`` when given.

        Returns:
            Wrapper whose ``pixel_values`` is the stacked batch of processed
            images. Every image in a batch has the same input shape, so the
            processed images share one shape and can be stacked.
        """
        if max_image_size is None:
            max_image_size = self.max_image_size
        if len(images.shape) == 3:
            images = images.unsqueeze(0)

        # Bind the base implementation once; zero-argument super() cannot be
        # used inside a comprehension on older Python versions.
        process_one = super().__call__
        out_stack = [
            process_one(image, max_image_size=max_image_size)
            for image in images
        ]

        output = torch.stack(out_stack, dim=0)
        return PixtralVisionImagePreprocessorCompatibleReturn(output)
|
|
|
|
|
|
class PixtralVisionEncoderCompatibleReturn:
    """Result wrapper that exposes encoder output as ``hidden_states``."""

    def __init__(self, hidden_states) -> None:
        # Stored as given; no copy or conversion is made.
        self.hidden_states = hidden_states
|
|
|
|
|
|
class PixtralVisionEncoderCompatibleConfig:
    """Minimal config object exposing the sizes of a vision encoder.

    All arguments are optional; calling the constructor without arguments
    yields the values 1024 / 1024 / 16.

    Args:
        image_size: Largest supported image side in pixels.
        hidden_size: Transformer width.
        patch_size: Side of one square patch in pixels.
    """

    def __init__(self, image_size: int = 1024, hidden_size: int = 1024, patch_size: int = 16):
        self.image_size = image_size
        self.hidden_size = hidden_size
        self.patch_size = patch_size
|
|
|
|
|
|
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
    """Encoder variant that takes a batched tensor and returns a wrapper.

    The constructor arguments are the same as for ``PixtralVisionEncoder``.
    ``self.config`` exposes ``image_size``, ``hidden_size`` and ``patch_size``
    with the values this encoder was actually built with.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        # Set the attributes after construction so the config mirrors the real
        # model sizes instead of always reporting the defaults.
        config = PixtralVisionEncoderCompatibleConfig()
        config.image_size = image_size
        config.hidden_size = hidden_size
        config.patch_size = patch_size
        self.config = config

    def forward(
        self,
        images,
        output_hidden_states=True,
    ) -> 'PixtralVisionEncoderCompatibleReturn':
        """Encode one image or a batch of equally sized images.

        Args:
            images: Tensor of shape (C, H, W) or (B, C, H, W).
            output_hidden_states: Accepted for call compatibility; not used.

        Returns:
            Wrapper whose ``hidden_states`` is a one-element list holding the
            stacked per-image token features of shape (B, N_toks, D).
        """
        if len(images.shape) == 3:
            images = images.unsqueeze(0)

        # Bind the base implementation once; zero-argument super() cannot be
        # used inside a comprehension on older Python versions.
        encode = super().forward
        # Each image is encoded on its own; the base class expects a list.
        out_stack = [encode([image]) for image in images]

        output = torch.stack(out_stack, dim=0)
        return PixtralVisionEncoderCompatibleReturn([output])
|