Added initial direct vision pixtral support

This commit is contained in:
Jaret Burkett
2024-09-28 10:47:51 -06:00
parent 86b5938cf3
commit 58537fc92b
7 changed files with 165 additions and 23 deletions

View File

@@ -33,7 +33,8 @@ class FeedForward(nn.Module):
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
# type: ignore
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -103,15 +104,18 @@ class Attention(nn.Module):
else:
cache.update(xk, xv)
key, val = cache.key, cache.value
key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
key = key.view(seqlen_sum * cache.max_seq_len,
self.n_kv_heads, self.head_dim)
val = val.view(seqlen_sum * cache.max_seq_len,
self.n_kv_heads, self.head_dim)
# Repeat keys and values to match number of query heads
key, val = repeat_kv(key, val, self.repeats, dim=1)
# xformers requires (B=1, S, H, D)
xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
output = memory_efficient_attention(
xq, key, val, mask if cache is None else cache.mask)
output = output.view(seqlen_sum, self.n_heads * self.head_dim)
assert isinstance(output, torch.Tensor)
@@ -260,8 +264,8 @@ class PixtralVisionEncoder(nn.Module):
assert head_dim % 2 == 0, "ROPE requires even head_dim"
self._freqs_cis: Optional[torch.Tensor] = None
@staticmethod
def from_pretrained(pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
if os.path.isdir(pretrained_model_name_or_path):
model_folder = pretrained_model_name_or_path
else:
@@ -275,11 +279,12 @@ class PixtralVisionEncoder(nn.Module):
with open(os.path.join(model_folder, "config.json"), "r") as f:
config = json.load(f)
model = PixtralVisionEncoder(**config)
model = cls(**config)
# see if there is a state_dict
if os.path.exists(os.path.join(model_folder, "model.safetensors")):
state_dict = load_file(os.path.join(model_folder, "model.safetensors"))
state_dict = load_file(os.path.join(
model_folder, "model.safetensors"))
model.load_state_dict(state_dict)
return model
@@ -319,14 +324,17 @@ class PixtralVisionEncoder(nn.Module):
image_features: tensor of token features for all tokens of all images of
shape (N_toks, D)
"""
assert isinstance(images, list), f"Expected list of images, got {type(images)}"
assert isinstance(
images, list), f"Expected list of images, got {type(images)}"
assert all(len(img.shape) == 3 for img in
images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
# pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
patch_embeds_list = [self.patch_conv(
img.unsqueeze(0)).squeeze(0) for img in images]
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
for p in patch_embeds_list], dim=0)
patch_embeds = self.ln_pre(patch_embeds)
# positional embeddings
@@ -355,7 +363,8 @@ class VisionLanguageAdapter(nn.Module):
self.w_out = nn.Linear(out_dim, out_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
# type: ignore[no-any-return]
return self.w_out(self.gelu(self.w_in(x)))
class VisionTransformerBlocks(nn.Module):
@@ -401,7 +410,8 @@ def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> tor
Returns:
torch.Tensor: Normalized image with shape (C, H, W).
"""
assert image.shape[0] == len(mean) == len(std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
assert image.shape[0] == len(mean) == len(
std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
# Reshape mean and std to (C, 1, 1) for broadcasting
mean = mean.view(-1, 1, 1)
@@ -473,10 +483,12 @@ class PixtralVisionImagePreprocessor:
"""
# should not have batch
if len(image.shape) == 4:
raise ValueError(f"Expected image with shape (C, H, W), got {image.shape}")
raise ValueError(
f"Expected image with shape (C, H, W), got {image.shape}")
if image.min() < 0.0 or image.max() > 1.0:
raise ValueError(f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
raise ValueError(
f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
w, h = self._image_to_num_tokens(image)
assert w > 0
@@ -490,3 +502,98 @@ class PixtralVisionImagePreprocessor:
processed_image = transform_image(image, new_image_size)
return processed_image
class PixtralVisionImagePreprocessorCompatibleReturn:
def __init__(self, pixel_values) -> None:
self.pixel_values = pixel_values
# Compatable version with ai toolkit flow
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
super().__init__(
image_patch_size=image_patch_size,
max_image_size=max_image_size
)
self.size = {
'height': max_image_size,
'width': max_image_size
}
self.image_mean = DATASET_MEAN
self.image_std = DATASET_STD
def __call__(
self,
images,
return_tensors="pt",
do_resize=True,
do_rescale=False,
) -> torch.Tensor:
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
processed_image = super().__call__(image)
out_stack.append(processed_image)
output = torch.stack(out_stack, dim=0)
return PixtralVisionImagePreprocessorCompatibleReturn(output)
class PixtralVisionEncoderCompatibleReturn:
def __init__(self, hidden_states) -> None:
self.hidden_states = hidden_states
class PixtralVisionEncoderCompatibleConfig:
def __init__(self):
self.image_size = 1024
self.hidden_size = 1024
self.patch_size = 16
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
def __init__(
self,
hidden_size: int = 1024,
num_channels: int = 3,
image_size: int = 1024,
patch_size: int = 16,
intermediate_size: int = 4096,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
rope_theta: float = 1e4, # for rope-2D
image_token_id: int = 10,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
num_channels=num_channels,
image_size=image_size,
patch_size=patch_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
rope_theta=rope_theta,
image_token_id=image_token_id,
)
self.config = PixtralVisionEncoderCompatibleConfig()
def forward(
self,
images,
output_hidden_states=True,
) -> torch.Tensor:
out_stack = []
if len(images.shape) == 3:
images = images.unsqueeze(0)
for i in range(images.shape[0]):
image = images[i]
# must be in an array
image_output = super().forward([image])
out_stack.append(image_output)
output = torch.stack(out_stack, dim=0)
return PixtralVisionEncoderCompatibleReturn([output])