mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added initial support for pixtral vision as a vision encoder.
This commit is contained in:
@@ -20,6 +20,10 @@ RUN apt-get install -y tmux nvtop htop
|
||||
|
||||
RUN pip install jupyterlab
|
||||
|
||||
# mask workspace
|
||||
RUN mkdir /workspace
|
||||
|
||||
|
||||
# symlink app to workspace
|
||||
RUN ln -s /app/ai-toolkit /workspace/ai-toolkit
|
||||
|
||||
|
||||
492
toolkit/models/pixtral_vision.py
Normal file
492
toolkit/models/pixtral_vision.py
Normal file
@@ -0,0 +1,492 @@
|
||||
from typing import List, Optional, Tuple, Any, Union
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
|
||||
from xformers.ops.fmha import memory_efficient_attention
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
import json
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
|
||||
|
||||
|
||||
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
|
||||
values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
|
||||
return keys, values
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis[:, None, :]
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
head_dim: int,
|
||||
n_kv_heads: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads: int = n_heads
|
||||
self.head_dim: int = head_dim
|
||||
self.n_kv_heads: int = n_kv_heads
|
||||
|
||||
self.repeats = self.n_heads // self.n_kv_heads
|
||||
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[BlockDiagonalMask] = None,
|
||||
) -> torch.Tensor:
|
||||
assert mask is None or cache is None
|
||||
seqlen_sum, _ = x.shape
|
||||
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
|
||||
xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
|
||||
xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if cache is None:
|
||||
key, val = xk, xv
|
||||
elif cache.prefill:
|
||||
key, val = cache.interleave_kv(xk, xv)
|
||||
cache.update(xk, xv)
|
||||
else:
|
||||
cache.update(xk, xv)
|
||||
key, val = cache.key, cache.value
|
||||
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
||||
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
|
||||
|
||||
# Repeat keys and values to match number of query heads
|
||||
key, val = repeat_kv(key, val, self.repeats, dim=1)
|
||||
|
||||
# xformers requires (B=1, S, H, D)
|
||||
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
|
||||
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
|
||||
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
|
||||
|
||||
assert isinstance(output, torch.Tensor)
|
||||
|
||||
return self.wo(output) # type: ignore
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
norm_eps: float,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.dim = dim
|
||||
self.attention = Attention(
|
||||
dim=dim,
|
||||
n_heads=n_heads,
|
||||
head_dim=head_dim,
|
||||
n_kv_heads=n_kv_heads,
|
||||
)
|
||||
self.attention_norm = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.feed_forward: nn.Module
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
cache: Optional[Any] = None,
|
||||
mask: Optional[BlockDiagonalMask] = None,
|
||||
) -> torch.Tensor:
|
||||
r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward.forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionEncoderArgs:
|
||||
hidden_size: int
|
||||
num_channels: int
|
||||
image_size: int
|
||||
patch_size: int
|
||||
intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
rope_theta: float = 1e4 # for rope-2D
|
||||
image_token_id: int = 10
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d(
|
||||
dim: int,
|
||||
height: int,
|
||||
width: int,
|
||||
theta: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
|
||||
(height, width) position tuples
|
||||
"""
|
||||
# (dim / 2) frequency bases
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
||||
|
||||
h = torch.arange(height, device=freqs.device)
|
||||
w = torch.arange(width, device=freqs.device)
|
||||
|
||||
freqs_h = torch.outer(h, freqs[::2]).float()
|
||||
freqs_w = torch.outer(w, freqs[1::2]).float()
|
||||
freqs_2d = torch.cat(
|
||||
[
|
||||
freqs_h[:, None, :].repeat(1, width, 1),
|
||||
freqs_w[None, :, :].repeat(height, 1, 1),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
|
||||
|
||||
|
||||
def position_meshgrid(
|
||||
patch_embeds_list: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
positions = torch.cat(
|
||||
[
|
||||
torch.stack(
|
||||
torch.meshgrid(
|
||||
torch.arange(p.shape[-2]),
|
||||
torch.arange(p.shape[-1]),
|
||||
indexing="ij",
|
||||
),
|
||||
dim=-1,
|
||||
).reshape(-1, 2)
|
||||
for p in patch_embeds_list
|
||||
]
|
||||
)
|
||||
return positions
|
||||
|
||||
|
||||
class PixtralVisionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
intermediate_size: int = 4096,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
rope_theta: float = 1e4, # for rope-2D
|
||||
image_token_id: int = 10,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.args = VisionEncoderArgs(
|
||||
hidden_size=hidden_size,
|
||||
num_channels=num_channels,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
rope_theta=rope_theta,
|
||||
image_token_id=image_token_id,
|
||||
)
|
||||
args = self.args
|
||||
self.patch_conv = nn.Conv2d(
|
||||
in_channels=args.num_channels,
|
||||
out_channels=args.hidden_size,
|
||||
kernel_size=args.patch_size,
|
||||
stride=args.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
|
||||
self.transformer = VisionTransformerBlocks(args)
|
||||
|
||||
head_dim = self.args.hidden_size // self.args.num_attention_heads
|
||||
assert head_dim % 2 == 0, "ROPE requires even head_dim"
|
||||
self._freqs_cis: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
model_folder = pretrained_model_name_or_path
|
||||
else:
|
||||
model_folder = snapshot_download(pretrained_model_name_or_path)
|
||||
|
||||
# make sure there is a config
|
||||
if not os.path.exists(os.path.join(model_folder, "config.json")):
|
||||
raise ValueError(f"Could not find config.json in {model_folder}")
|
||||
|
||||
# load config
|
||||
with open(os.path.join(model_folder, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = PixtralVisionEncoder(**config)
|
||||
|
||||
# see if there is a state_dict
|
||||
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
|
||||
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
return model
|
||||
|
||||
@property
|
||||
def max_patches_per_side(self) -> int:
|
||||
return self.args.image_size // self.args.patch_size
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def freqs_cis(self) -> torch.Tensor:
|
||||
if self._freqs_cis is None:
|
||||
self._freqs_cis = precompute_freqs_cis_2d(
|
||||
dim=self.args.hidden_size // self.args.num_attention_heads,
|
||||
height=self.max_patches_per_side,
|
||||
width=self.max_patches_per_side,
|
||||
theta=self.args.rope_theta,
|
||||
)
|
||||
|
||||
if self._freqs_cis.device != self.device:
|
||||
self._freqs_cis = self._freqs_cis.to(device=self.device)
|
||||
|
||||
return self._freqs_cis
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
images: list of N_img images of variable sizes, each of shape (C, H, W)
|
||||
|
||||
Returns:
|
||||
image_features: tensor of token features for all tokens of all images of
|
||||
shape (N_toks, D)
|
||||
"""
|
||||
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
|
||||
assert all(len(img.shape) == 3 for img in
|
||||
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
|
||||
# pass images through initial convolution independently
|
||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
|
||||
|
||||
# flatten to a single sequence
|
||||
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
|
||||
patch_embeds = self.ln_pre(patch_embeds)
|
||||
|
||||
# positional embeddings
|
||||
positions = position_meshgrid(patch_embeds_list).to(self.device)
|
||||
freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
|
||||
|
||||
# pass through Transformer with a block diagonal mask delimiting images
|
||||
mask = BlockDiagonalMask.from_seqlens(
|
||||
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
|
||||
)
|
||||
out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
|
||||
|
||||
# remove batch dimension of the single sequence
|
||||
return out # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class VisionLanguageAdapter(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.w_in = nn.Linear(
|
||||
in_dim,
|
||||
out_dim,
|
||||
bias=True,
|
||||
)
|
||||
self.gelu = nn.GELU()
|
||||
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class VisionTransformerBlocks(nn.Module):
|
||||
def __init__(self, args: VisionEncoderArgs):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for _ in range(args.num_hidden_layers):
|
||||
self.layers.append(
|
||||
TransformerBlock(
|
||||
dim=args.hidden_size,
|
||||
hidden_dim=args.intermediate_size,
|
||||
n_heads=args.num_attention_heads,
|
||||
n_kv_heads=args.num_attention_heads,
|
||||
head_dim=args.hidden_size // args.num_attention_heads,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: BlockDiagonalMask,
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=mask, freqs_cis=freqs_cis)
|
||||
return x
|
||||
|
||||
|
||||
DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073] # RGB
|
||||
DATASET_STD = [0.26862954, 0.26130258, 0.27577711] # RGB
|
||||
|
||||
|
||||
def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Normalize a tensor image with mean and standard deviation.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): Image to be normalized, shape (C, H, W), values in [0, 1].
|
||||
mean (torch.Tensor): Mean for each channel.
|
||||
std (torch.Tensor): Standard deviation for each channel.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Normalized image with shape (C, H, W).
|
||||
"""
|
||||
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
|
||||
|
||||
# Reshape mean and std to (C, 1, 1) for broadcasting
|
||||
mean = mean.view(-1, 1, 1)
|
||||
std = std.view(-1, 1, 1)
|
||||
|
||||
return (image - mean) / std
|
||||
|
||||
|
||||
def transform_image(image: torch.Tensor, new_size: tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Resize and normalize the input image.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): Input image tensor of shape (C, H, W), values in [0, 1].
|
||||
new_size (tuple[int, int]): Target size (height, width) for resizing.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Resized and normalized image tensor of shape (C, new_H, new_W).
|
||||
"""
|
||||
# Resize the image
|
||||
resized_image = torch.nn.functional.interpolate(
|
||||
image.unsqueeze(0),
|
||||
size=new_size,
|
||||
mode='bicubic',
|
||||
align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Normalize the image
|
||||
normalized_image = normalize(
|
||||
resized_image,
|
||||
torch.tensor(DATASET_MEAN, device=image.device, dtype=image.dtype),
|
||||
torch.tensor(DATASET_STD, device=image.device, dtype=image.dtype)
|
||||
)
|
||||
|
||||
return normalized_image
|
||||
|
||||
|
||||
class PixtralVisionImagePreprocessor:
|
||||
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
|
||||
self.image_patch_size = image_patch_size
|
||||
self.max_image_size = max_image_size
|
||||
self.image_token = 10
|
||||
|
||||
def _image_to_num_tokens(self, img: torch.Tensor) -> Tuple[int, int]:
|
||||
w: Union[int, float]
|
||||
h: Union[int, float]
|
||||
|
||||
w, h = img.shape[-1], img.shape[-2]
|
||||
|
||||
ratio = max(h / self.max_image_size, w / self.max_image_size)
|
||||
if ratio > 1:
|
||||
w = round(w / ratio)
|
||||
h = round(h / ratio)
|
||||
|
||||
width_tokens = (w - 1) // self.image_patch_size + 1
|
||||
height_tokens = (h - 1) // self.image_patch_size + 1
|
||||
|
||||
return width_tokens, height_tokens
|
||||
|
||||
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts ImageChunks to numpy image arrays and image token ids
|
||||
|
||||
Args:
|
||||
image torch tensor with values 0-1 and shape of (C, H, W)
|
||||
|
||||
Returns:
|
||||
processed_image: tensor of token features for all tokens of all images of
|
||||
"""
|
||||
# should not have batch
|
||||
if len(image.shape) == 4:
|
||||
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
|
||||
|
||||
if image.min() < 0.0 or image.max() > 1.0:
|
||||
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
|
||||
|
||||
w, h = self._image_to_num_tokens(image)
|
||||
assert w > 0
|
||||
assert h > 0
|
||||
|
||||
new_image_size = (
|
||||
w * self.image_patch_size,
|
||||
h * self.image_patch_size,
|
||||
)
|
||||
|
||||
processed_image = transform_image(image, new_image_size)
|
||||
|
||||
return processed_image
|
||||
Reference in New Issue
Block a user